Bug fix; also make the final norm of Conv2dSubsampling a ConvNorm1d

This commit is contained in:
Daniel Povey 2022-12-20 19:44:04 +08:00
parent 3b4b33af58
commit 71880409cc
2 changed files with 12 additions and 11 deletions

View File

@ -511,7 +511,8 @@ class PositiveConv1d(nn.Conv1d):
self.max = max
# initialize weight to all positive values.
self.weight[:] = 1.0 / self.weight[0][0].numel()
with torch.no_grad():
self.weight[:] = 1.0 / self.weight[0][0].numel()
def forward(self, input: Tensor) -> Tensor:
"""
@ -519,7 +520,6 @@ class PositiveConv1d(nn.Conv1d):
(N, C, H)
i.e. (batch_size, num_channels, height)
"""
weight = self.weight
weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max))
# make absolutely sure there are no negative values. For parameter-averaging-related
# reasons, we prefer to also use limit_param_value to make sure the weights stay
@ -556,7 +556,6 @@ class ConvNorm1d(torch.nn.Module):
def __init__(
self,
num_channels: int,
channel_dim: int = -1, # CAUTION: see documentation.
eps: float = 0.25,
learn_eps: bool = True,
eps_min: float = -3.0,
@ -567,7 +566,6 @@ class ConvNorm1d(torch.nn.Module):
) -> None:
super().__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
if learn_eps:
self.eps = nn.Parameter(torch.tensor(eps).log().detach())
else:
@ -576,7 +574,8 @@ class ConvNorm1d(torch.nn.Module):
self.eps_max = eps_max
pad = kernel_size // 2
# it has bias=False.
self.conv = PositiveConv1d(1, 1, kernel_size=kernel_size, padding=pad)
self.conv = PositiveConv1d(1, 1, kernel_size=kernel_size, padding=pad,
min=conv_min, max=conv_max)
def forward(self, x: Tensor,
@ -598,7 +597,7 @@ class ConvNorm1d(torch.nn.Module):
# gradients to allow the parameter to get back into the allowed
# region if it happens to exit it.
eps = eps.clamp(min=self.eps_min, max=self.eps_max)
eps = torch.clamp(eps, min=self.eps_min, max=self.eps_max)
# sqnorms: (N, 1, T)
sqnorms = (
@ -611,9 +610,9 @@ class ConvNorm1d(torch.nn.Module):
sqnorms = sqnorms * counts
sqnorms = self.conv(sqnorms)
# the clamping is to avoid division by zero for padding frames.
counts = self.conv(counts).clamp(min=0.01)
counts = torch.clamp(self.conv(counts), min=0.01)
# scales: (N, 1, T)
scales = (sqnorms / counts + eps.exp()) ** -0.5
scales = (sqnorms / counts + eps.exp()) ** -0.5 #
return x * scales

View File

@ -1824,7 +1824,7 @@ class Conv2dSubsampling(nn.Module):
self.out = nn.Linear(out_height * layer3_channels, out_channels)
self.out_norm = BasicNorm(out_channels, channel_dim=-1)
self.out_norm = ConvNorm1d(out_channels)
self.dropout = Dropout2(dropout)
@ -1859,9 +1859,11 @@ class Conv2dSubsampling(nn.Module):
min=float(self.scale_min),
max=float(self.scale_max))
x = self.out(x)
x = self.out_norm(x)
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
x = self.out(x)
x = x.transpose(1, 2) # (batch, channels, time)
x = self.out_norm(x)
x = x.transpose(1, 2) # (batch, time=((T-1)//2 - 1))//2, channels)
x = self.dropout(x)
return x