mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Merge branch 'scaled_adam_exp829' into scaled_adam_exp860
# Conflicts: # egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
This commit is contained in:
commit
9b0c0aabb2
@ -406,6 +406,7 @@ class ZipformerEncoderLayer(nn.Module):
|
|||||||
# to work correctly.
|
# to work correctly.
|
||||||
layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
|
layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
|
||||||
attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
|
attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
|
||||||
|
small_conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.2), (16000, 0.1), default=0),
|
||||||
conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
|
conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
|
||||||
const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
|
const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
|
||||||
ff2_skip_rate: FloatLike = 0.01,
|
ff2_skip_rate: FloatLike = 0.01,
|
||||||
@ -422,6 +423,11 @@ class ZipformerEncoderLayer(nn.Module):
|
|||||||
# an additional skip probability that applies to ConvModule to stop it from
|
# an additional skip probability that applies to ConvModule to stop it from
|
||||||
# contributing too much early on.
|
# contributing too much early on.
|
||||||
self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
|
self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
|
||||||
|
|
||||||
|
# skip rate for small_conv_module; it is fairly high and remains nonzero
|
||||||
|
# because we don't want this submodule to contribute too much.
|
||||||
|
self.small_conv_skip_rate = copy.deepcopy(small_conv_skip_rate)
|
||||||
|
|
||||||
# ff2_skip_rate is to prevent the ff2 module from having output that's too big
|
# ff2_skip_rate is to prevent the ff2 module from having output that's too big
|
||||||
# compared to its residual.
|
# compared to its residual.
|
||||||
self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
|
self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
|
||||||
@ -452,6 +458,7 @@ class ZipformerEncoderLayer(nn.Module):
|
|||||||
self.nonlin_attention = NonlinAttention(embed_dim,
|
self.nonlin_attention = NonlinAttention(embed_dim,
|
||||||
hidden_channels=embed_dim // 4)
|
hidden_channels=embed_dim // 4)
|
||||||
|
|
||||||
|
self.small_conv_module = SmallConvolutionModule(embed_dim)
|
||||||
|
|
||||||
self.conv_module = ConvolutionModule(embed_dim,
|
self.conv_module = ConvolutionModule(embed_dim,
|
||||||
cnn_module_kernel)
|
cnn_module_kernel)
|
||||||
@ -596,6 +603,10 @@ class ZipformerEncoderLayer(nn.Module):
|
|||||||
src = src + self.balancer_na(self.nonlin_attention(src,
|
src = src + self.balancer_na(self.nonlin_attention(src,
|
||||||
selected_attn_weights[0:1]))
|
selected_attn_weights[0:1]))
|
||||||
|
|
||||||
|
|
||||||
|
if torch.jit.is_scripting() or random.random() >= float(self.small_conv_skip_rate):
|
||||||
|
src = src + self.small_conv_module(src, src_key_padding_mask=src_key_padding_mask)
|
||||||
|
|
||||||
src = src + self.feed_forward1(src)
|
src = src + self.feed_forward1(src)
|
||||||
|
|
||||||
# pooling module
|
# pooling module
|
||||||
@ -937,6 +948,92 @@ class SimpleCombiner(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class SmallConvolutionModule(nn.Module):
|
||||||
|
"""Part of Zipformer model: a small version of the Convolution module that uses a small kernel.
|
||||||
|
Inspired by convnext (i.e. have the depthwise conv first.)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channels (int): The number of channels of conv layers.
|
||||||
|
kernel_size (int): Kernerl size of conv layers.
|
||||||
|
bias (bool): Whether to use bias in conv layers (default=True).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, channels: int,
|
||||||
|
hidden_dim: int = 128,
|
||||||
|
kernel_size: int = 5,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
|
||||||
|
self.depthwise_conv = nn.Conv1d(
|
||||||
|
in_channels=channels,
|
||||||
|
out_channels=channels,
|
||||||
|
groups=channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
padding=kernel_size // 2)
|
||||||
|
|
||||||
|
self.pointwise_conv1 = nn.Conv1d(
|
||||||
|
channels,
|
||||||
|
hidden_dim,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
# balancer and activation as tuned for ConvolutionModule.
|
||||||
|
|
||||||
|
self.balancer = Balancer(
|
||||||
|
hidden_dim, channel_dim=1,
|
||||||
|
min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
|
||||||
|
max_positive=1.0,
|
||||||
|
min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
|
||||||
|
max_abs=10.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.activation = SwooshR()
|
||||||
|
|
||||||
|
self.pointwise_conv2 = ScaledConv1d(
|
||||||
|
hidden_dim,
|
||||||
|
channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
bias=True,
|
||||||
|
initial_scale=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
x: Tensor,
|
||||||
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
) -> Tensor:
|
||||||
|
"""Compute convolution module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor (#time, batch, channels).
|
||||||
|
src_key_padding_mask: the mask for the src keys per batch (optional):
|
||||||
|
(batch, #time), contains bool in masked positions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Output tensor (#time, batch, channels).
|
||||||
|
"""
|
||||||
|
# exchange the temporal dimension and the feature dimension
|
||||||
|
x = x.permute(1, 2, 0) # (#batch, channels, time).
|
||||||
|
|
||||||
|
if src_key_padding_mask is not None:
|
||||||
|
x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
|
||||||
|
|
||||||
|
x = self.depthwise_conv(x) # (batch, channels, time)
|
||||||
|
x = self.pointwise_conv1(x) # (batch, hidden_dim, time)
|
||||||
|
|
||||||
|
x = self.balancer(x)
|
||||||
|
x = self.activation(x)
|
||||||
|
x = self.pointwise_conv2(x)
|
||||||
|
|
||||||
|
return x.permute(2, 0, 1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CompactRelPositionalEncoding(torch.nn.Module):
|
class CompactRelPositionalEncoding(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
Relative positional encoding module. This version is "compact" meaning it is able to encode
|
Relative positional encoding module. This version is "compact" meaning it is able to encode
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user