Replace BasicNorm of encoder layers with ConvNorm1d

This commit is contained in:
Daniel Povey 2022-12-20 19:15:14 +08:00
parent f59697555f
commit 494139d27a
2 changed files with 126 additions and 4 deletions

View File

@ -498,6 +498,125 @@ class BasicNorm(torch.nn.Module):
class PositiveConv1d(nn.Conv1d):
"""
A modified form of nn.Conv1d where the weight parameters are constrained
to be positive and there is no bias.
"""
def __init__(
self, *args, min: FloatLike = 0.01, max: FloatLike = 1.0,
**kwargs):
super().__init__(*args, **kwargs, bias=False)
self.min = min
self.max = max
# initialize weight to all positive values.
self.weight[:] = 1.0 / self.weight[0][0].numel()
def forward(self, input: Tensor) -> Tensor:
"""
Forward function. Input and returned tensor have shape:
(N, C, H)
i.e. (batch_size, num_channels, height)
"""
weight = self.weight
weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max))
# make absolutely sure there are no negative values. For parameter-averaging-related
# reasons, we prefer to also use limit_param_value to make sure the weights stay
# positive.
weight = weight.abs()
if self.padding_mode != 'zeros':
return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
weight, self.bias, self.stride,
_single(0), self.dilation, self.groups)
return F.conv1d(input, weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
class ConvNorm1d(torch.nn.Module):
"""
This is like BasicNorm except the denominator is summed over time using
convolution with positive weights.
Args:
num_channels: the number of channels, e.g. 512.
eps: the initial "epsilon" that we add as ballast in:
scale = ((input_vec**2).mean() + epsilon)**-0.5
Note: our epsilon is actually large, but we keep the name
to indicate the connection with conventional LayerNorm.
learn_eps: if true, we learn epsilon; if false, we keep it
at the initial value.
eps_min: float
eps_max: float
"""
def __init__(
self,
num_channels: int,
channel_dim: int = -1, # CAUTION: see documentation.
eps: float = 0.25,
learn_eps: bool = True,
eps_min: float = -3.0,
eps_max: float = 3.0,
conv_min: float = 0.1,
conv_max: float = 1.0,
kernel_size: int = 15,
) -> None:
super().__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
if learn_eps:
self.eps = nn.Parameter(torch.tensor(eps).log().detach())
else:
self.register_buffer("eps", torch.tensor(eps).log().detach())
self.eps_min = eps_min
self.eps_max = eps_max
pad = kernel_size // 2
# it has bias=False.
self.conv = PositiveConv1d(1, 1, kernel_size=kernel_size, padding=pad)
def forward(self, x: Tensor,
src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
"""
x shape: (N, C, T)
src_key_padding_mask: the mask for the src keys per batch (optional):
(N, T), contains True in masked positions.
"""
assert x.ndim == 3 and x.shape[1] == self.num_channels
eps = self.eps
if self.training and random.random() < 0.25:
# with probability 0.25, in training mode, clamp eps between the min
# and max; this will encourage it to learn parameters within the
# allowed range by making parameters that are outside the allowed
# range noisy.
# gradients to allow the parameter to get back into the allowed
# region if it happens to exit it.
eps = eps.clamp(min=self.eps_min, max=self.eps_max)
# sqnorms: (N, 1, T)
sqnorms = (
torch.mean(x ** 2, dim=1, keepdim=True)
)
# 'counts' is a mechanism to correct for edge effects.
counts = torch.ones_like(sqnorms)
if src_key_padding_mask is not None:
counts = counts.masked_fill_(src_key_padding_mask.unsqueeze(1), 0.0)
sqnorms = sqnorms * counts
sqnorms = self.conv(sqnorms)
counts = self.conv(counts)
# scales: (N, 1, T)
scales = (sqnorms / counts + eps.exp()) ** -0.5
return x * scales
def ScaledLinear(*args,
initial_scale: float = 1.0,

View File

@ -27,6 +27,7 @@ from encoder_interface import EncoderInterface
from scaling import (
ActivationBalancer,
BasicNorm,
ConvNorm1d,
Dropout2,
MaxEig,
DoubleSwish,
@ -443,7 +444,7 @@ class ZipformerEncoderLayer(nn.Module):
self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
self.norm_final = BasicNorm(embed_dim)
self.norm_final = ConvNorm1d(embed_dim)
self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
@ -555,8 +556,10 @@ class ZipformerEncoderLayer(nn.Module):
src = src + self.feed_forward2(src)
src = self.norm_final(self.balancer(src))
src = self.balancer(src)
src = src.permute(1, 2, 0) # (batch, channels, time)
src = self.norm_final(src, src_key_padding_mask)
src = src.permute(2, 0, 1) # (time, batch, channels)
delta = src - src_orig
@ -1606,7 +1609,7 @@ class ConvolutionModule(nn.Module):
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional):
(batch, #time), contains bool in masked positions.
(batch, #time), contains True in masked positions.
Returns:
Tensor: Output tensor (#time, batch, channels).