mirror of https://github.com/k2-fsa/icefall.git
Replace BasicNorm of encoder layers with ConvNorm1d
commit 494139d27a (parent f59697555f)
@@ -498,6 +498,125 @@ class BasicNorm(torch.nn.Module):
 
 
+class PositiveConv1d(nn.Conv1d):
+    """
+    A modified form of nn.Conv1d where the weight parameters are constrained
+    to be positive and there is no bias.
+    """
+    def __init__(
+            self, *args, min: FloatLike = 0.01, max: FloatLike = 1.0,
+            **kwargs):
+        super().__init__(*args, **kwargs, bias=False)
+        self.min = min
+        self.max = max
+
+        # initialize weight to all positive values.
+        self.weight[:] = 1.0 / self.weight[0][0].numel()
+
+    def forward(self, input: Tensor) -> Tensor:
+        """
+        Forward function.  Input and returned tensor have shape:
+            (N, C, H)
+        i.e. (batch_size, num_channels, height)
+        """
+        weight = self.weight
+        weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max))
+        # make absolutely sure there are no negative values.  For parameter-averaging-related
+        # reasons, we prefer to also use limit_param_value to make sure the weights stay
+        # positive.
+        weight = weight.abs()
+
+        if self.padding_mode != 'zeros':
+            return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
+                            weight, self.bias, self.stride,
+                            _single(0), self.dilation, self.groups)
+        return F.conv1d(input, weight, self.bias, self.stride,
+                        self.padding, self.dilation, self.groups)
+
+
+class ConvNorm1d(torch.nn.Module):
+    """
+    This is like BasicNorm except the denominator is summed over time using
+    convolution with positive weights.
+
+    Args:
+      num_channels: the number of channels, e.g. 512.
+      eps: the initial "epsilon" that we add as ballast in:
+             scale = ((input_vec**2).mean() + epsilon)**-0.5
+           Note: our epsilon is actually large, but we keep the name
+           to indicate the connection with conventional LayerNorm.
+      learn_eps: if true, we learn epsilon; if false, we keep it
+           at the initial value.
+      eps_min: float
+      eps_max: float
+    """
+    def __init__(
+        self,
+        num_channels: int,
+        channel_dim: int = -1,  # CAUTION: see documentation.
+        eps: float = 0.25,
+        learn_eps: bool = True,
+        eps_min: float = -3.0,
+        eps_max: float = 3.0,
+        conv_min: float = 0.1,
+        conv_max: float = 1.0,
+        kernel_size: int = 15,
+    ) -> None:
+        super().__init__()
+        self.num_channels = num_channels
+        self.channel_dim = channel_dim
+        if learn_eps:
+            self.eps = nn.Parameter(torch.tensor(eps).log().detach())
+        else:
+            self.register_buffer("eps", torch.tensor(eps).log().detach())
+        self.eps_min = eps_min
+        self.eps_max = eps_max
+        pad = kernel_size // 2
+        # it has bias=False.
+        self.conv = PositiveConv1d(1, 1, kernel_size=kernel_size, padding=pad)
+
+    def forward(self, x: Tensor,
+                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
+        """
+        x shape: (N, C, T)
+
+        src_key_padding_mask: the mask for the src keys per batch (optional):
+            (N, T), contains True in masked positions.
+        """
+        assert x.ndim == 3 and x.shape[1] == self.num_channels
+        eps = self.eps
+        if self.training and random.random() < 0.25:
+            # with probability 0.25, in training mode, clamp eps between the min
+            # and max; this will encourage it to learn parameters within the
+            # allowed range by making parameters that are outside the allowed
+            # range noisy.  Doing this only part of the time leaves some
+            # gradients to allow the parameter to get back into the allowed
+            # region if it happens to exit it.
+            eps = eps.clamp(min=self.eps_min, max=self.eps_max)
+
+        # sqnorms: (N, 1, T)
+        sqnorms = torch.mean(x ** 2, dim=1, keepdim=True)
+        # 'counts' is a mechanism to correct for edge effects.
+        counts = torch.ones_like(sqnorms)
+        if src_key_padding_mask is not None:
+            counts = counts.masked_fill_(src_key_padding_mask.unsqueeze(1), 0.0)
+            sqnorms = sqnorms * counts
+        sqnorms = self.conv(sqnorms)
+        counts = self.conv(counts)
+        # scales: (N, 1, T)
+        scales = (sqnorms / counts + eps.exp()) ** -0.5
+        return x * scales
 
 
 def ScaledLinear(*args,
                  initial_scale: float = 1.0,
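For readers of this hunk, here is a small self-contained sketch of what ConvNorm1d computes; it is illustrative only, not code from the commit, and all shapes and values are made up. The per-frame mean of x**2 over channels and an all-ones "counts" tensor are smoothed over time by the same positive-weight convolution, and x is scaled by (sqnorms / counts + eps) ** -0.5, so padded and edge frames do not dilute the denominator. The fixed uniform kernel stands in for PositiveConv1d's learned weights.

import torch
import torch.nn.functional as F

N, C, T, K = 2, 8, 20, 5                    # batch, channels, frames, kernel size (illustrative)
x = torch.randn(N, C, T)
mask = torch.zeros(N, T, dtype=torch.bool)
mask[:, 18:] = True                         # True marks padded frames

eps = 0.25                                  # the module stores log(eps) and adds eps.exp(); same value here
kernel = torch.full((1, 1, K), 1.0 / K)     # fixed positive weights, standing in for PositiveConv1d

sqnorms = torch.mean(x ** 2, dim=1, keepdim=True)    # (N, 1, T)
counts = torch.ones_like(sqnorms)
counts = counts.masked_fill(mask.unsqueeze(1), 0.0)  # zero out padded frames
sqnorms = sqnorms * counts

pad = K // 2
sqnorms = F.conv1d(sqnorms, kernel, padding=pad)     # time-smoothed numerator
counts = F.conv1d(counts, kernel, padding=pad)       # how much unmasked signal each window saw

scales = (sqnorms / counts + eps) ** -0.5            # (N, 1, T)
y = x * scales                                       # same shape as x: (2, 8, 20)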
@@ -27,6 +27,7 @@ from encoder_interface import EncoderInterface
 from scaling import (
     ActivationBalancer,
     BasicNorm,
+    ConvNorm1d,
     Dropout2,
     MaxEig,
     DoubleSwish,
@@ -443,7 +444,7 @@ class ZipformerEncoderLayer(nn.Module):
 
         self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
 
-        self.norm_final = BasicNorm(embed_dim)
+        self.norm_final = ConvNorm1d(embed_dim)
 
         self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
 
@@ -555,8 +556,10 @@ class ZipformerEncoderLayer(nn.Module):
 
         src = src + self.feed_forward2(src)
 
-        src = self.norm_final(self.balancer(src))
+        src = self.balancer(src)
+        src = src.permute(1, 2, 0)  # (batch, channels, time)
+        src = self.norm_final(src, src_key_padding_mask)
+        src = src.permute(2, 0, 1)  # (time, batch, channels)
 
         delta = src - src_orig
 
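The permutes above exist because the encoder carries activations as (time, batch, channels) while ConvNorm1d expects (batch, channels, time) and now also takes the padding mask. A minimal sketch of the new call pattern (illustrative only; the identity stand-in replaces the real module so the snippet runs without scaling.py):

import torch

def norm_final(x, src_key_padding_mask=None):
    # stand-in for ConvNorm1d.forward; the real module returns x * scales
    return x

T, N, C = 50, 4, 256                            # made-up sizes
src = torch.randn(T, N, C)                      # (time, batch, channels)
mask = torch.zeros(N, T, dtype=torch.bool)      # True in padded positions

src = src.permute(1, 2, 0)                      # (batch, channels, time)
src = norm_final(src, src_key_padding_mask=mask)
src = src.permute(2, 0, 1)                      # back to (time, batch, channels)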
@@ -1606,7 +1609,7 @@ class ConvolutionModule(nn.Module):
         Args:
             x: Input tensor (#time, batch, channels).
             src_key_padding_mask: the mask for the src keys per batch (optional):
-                (batch, #time), contains bool in masked positions.
+                (batch, #time), contains True in masked positions.
 
         Returns:
             Tensor: Output tensor (#time, batch, channels).
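Both docstrings above use the same convention for src_key_padding_mask. As a sketch (not code from this repository), such a mask can be built from per-utterance frame counts like this:

import torch

lengths = torch.tensor([50, 43, 37])    # frames per utterance (made-up)
max_len = int(lengths.max())
# position t of utterance i is masked (True) when t >= lengths[i]
mask = torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)   # (batch, #time), torch.bool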