Remove changes in previous merge commit that did not relate to length_factor.

Daniel Povey 2022-11-21 14:32:05 +08:00
parent a6770657c8
commit 211e3af680
2 changed files with 2 additions and 81 deletions

View File

@@ -122,7 +122,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--feedforward-dim",
         type=str,
-        default="1280,1280,1280,1792,1280,1280",
+        default="1536,1536,2048,1536,1536,1536",
         help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
     )
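For reference, the restored default is a comma-separated string with one feedforward dimension per zipformer encoder stack. A minimal sketch of how such a string could be turned into per-stack integers; the helper name parse_int_list is hypothetical and not part of this diff:

# Hypothetical helper, not from this diff: split the comma-separated
# "--feedforward-dim" string into one integer per encoder stack.
def parse_int_list(s: str) -> list:
    return [int(v) for v in s.split(",")]

feedforward_dim = parse_int_list("1536,1536,2048,1536,1536,1536")
assert feedforward_dim == [1536, 1536, 2048, 1536, 1536, 1536]  # one value per stack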

View File

@@ -404,8 +404,6 @@ class ZipformerEncoderLayer(nn.Module):
         self.nonlin_attention_module = NonlinAttentionModule(embed_dim)
-        self.small_conv_module = SmallConvolutionModule(embed_dim)
         self.conv_module = ConvolutionModule(embed_dim,
                                              cnn_module_kernel)
@@ -471,8 +469,6 @@ class ZipformerEncoderLayer(nn.Module):
         # multi-headed self-attention module
         use_self_attn = (random.random() >= dynamic_skip_rate)
-        src = src + self.feed_forward1(src)
         if torch.jit.is_scripting() or use_self_attn:
             # attn_weights: (num_heads, batch_size, seq_len, seq_len)
             attn_weights = self.self_attn_weights(
@@ -487,8 +483,7 @@ class ZipformerEncoderLayer(nn.Module):
                 attn_weights[0:1])
-        if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
-            src = src + self.small_conv_module(src, src_key_padding_mask=src_key_padding_mask)
+        src = src + self.feed_forward1(src)
         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
@@ -1575,80 +1570,6 @@ class ConvolutionModule(nn.Module):
         return x.permute(2, 0, 1)
 
 
-class SmallConvolutionModule(nn.Module):
-    """Part of Zipformer model: a small version of the Convolution module that uses a small kernel.
-
-    Args:
-        channels (int): The number of channels of conv layers.
-        kernel_size (int): Kernel size of conv layers.
-        bias (bool): Whether to use bias in conv layers (default=True).
-    """
-
-    def __init__(
-        self, channels: int, hidden_dim: int = 256,
-    ) -> None:
-        super().__init__()
-        self.conv1 = nn.Conv1d(
-            channels,
-            hidden_dim,
-            kernel_size=3,
-            stride=1,
-            padding=1,
-            bias=True,
-        )
-        self.deriv_balancer = ActivationBalancer(
-            hidden_dim, channel_dim=1,
-            min_positive=0.05, max_positive=1.0,
-            max_abs=20.0,
-        )
-        self.activation = DoubleSwish()
-        self.conv2 = ScaledConv1d(
-            hidden_dim,
-            channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-            bias=True,
-            initial_scale=0.05,
-        )
-
-    def forward(self,
-                x: Tensor,
-                src_key_padding_mask: Optional[Tensor] = None,
-                ) -> Tensor:
-        """Compute convolution module.
-
-        Args:
-            x: Input tensor (#time, batch, channels).
-            src_key_padding_mask: the mask for the src keys per batch (optional):
-                (batch, #time), contains bool in masked positions.
-
-        Returns:
-            Tensor: Output tensor (#time, batch, channels).
-        """
-        # exchange the temporal dimension and the feature dimension
-        x = x.permute(1, 2, 0)  # (#batch, channels, time).
-        x = self.conv1(x)  # (batch, hidden_dim, time)
-        x = self.deriv_balancer(x)
-        x = self.activation(x)
-        if src_key_padding_mask is not None:
-            x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
-        x = self.conv2(x)
-        return x.permute(2, 0, 1)
-
-
 class Conv2dSubsampling(nn.Module):
     """Convolutional 2D subsampling (to 1/2 length).