mirror of https://github.com/k2-fsa/icefall.git
Bug fix; also make the final norm of Conv2dSubsampling a ConvNorm1d
commit 71880409cc (parent 3b4b33af58)
@@ -511,7 +511,8 @@ class PositiveConv1d(nn.Conv1d):
         self.max = max

         # initialize weight to all positive values.
-        self.weight[:] = 1.0 / self.weight[0][0].numel()
+        with torch.no_grad():
+            self.weight[:] = 1.0 / self.weight[0][0].numel()

     def forward(self, input: Tensor) -> Tensor:
         """
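For illustration, a minimal standalone sketch of the bug being fixed here, with a plain nn.Conv1d standing in for PositiveConv1d: an in-place write to a leaf parameter that requires grad raises a RuntimeError unless it happens under torch.no_grad().

    import torch
    import torch.nn as nn

    conv = nn.Conv1d(1, 1, kernel_size=3)

    # Before the fix: "RuntimeError: a leaf Variable that requires grad
    # is being used in an in-place operation."
    try:
        conv.weight[:] = 1.0 / conv.weight[0][0].numel()
    except RuntimeError as e:
        print("without no_grad():", e)

    # After the fix: suspend autograd while writing the initial value.
    with torch.no_grad():
        conv.weight[:] = 1.0 / conv.weight[0][0].numel()
    print(conv.weight.squeeze())  # every tap equals 1/3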
@@ -519,7 +520,6 @@ class PositiveConv1d(nn.Conv1d):
            (N, C, H)
         i.e. (batch_size, num_channels, height)
         """
-        weight = self.weight
         weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max))
         # make absolutely sure there are no negative values. For parameter-averaging-related
         # reasons, we prefer to also use limit_param_value to make sure the weights stay
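limit_param_value is an icefall helper whose definition lies outside this diff; as a hypothetical stand-in for intuition only, one can picture a clamp whose gradient passes straight through, so weights that drift outside [min, max] still receive gradients pulling them back:

    import torch
    from torch import Tensor

    def limit_param_value(x: Tensor, min: float, max: float) -> Tensor:
        # Hypothetical sketch, not the actual icefall implementation:
        # clamp the value in the forward pass, but let gradients flow as
        # if no clamp had been applied (straight-through estimator).
        return x + (x.clamp(min=min, max=max) - x).detach()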
@@ -556,7 +556,6 @@ class ConvNorm1d(torch.nn.Module):
     def __init__(
         self,
         num_channels: int,
-        channel_dim: int = -1, # CAUTION: see documentation.
         eps: float = 0.25,
         learn_eps: bool = True,
         eps_min: float = -3.0,
@@ -567,7 +566,6 @@ class ConvNorm1d(torch.nn.Module):
     ) -> None:
         super().__init__()
         self.num_channels = num_channels
-        self.channel_dim = channel_dim
         if learn_eps:
             self.eps = nn.Parameter(torch.tensor(eps).log().detach())
         else:
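Note that the learned eps is stored in log space, so eps.exp() is strictly positive no matter where the optimizer moves the parameter. A quick sketch:

    import torch
    import torch.nn as nn

    log_eps = nn.Parameter(torch.tensor(0.25).log().detach())
    # exp() recovers 0.25 and stays positive for any value of log_eps.
    print(log_eps.exp())  # tensor(0.2500, grad_fn=<ExpBackward0>)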
@@ -576,7 +574,8 @@ class ConvNorm1d(torch.nn.Module):
         self.eps_max = eps_max
         pad = kernel_size // 2
         # it has bias=False.
-        self.conv = PositiveConv1d(1, 1, kernel_size=kernel_size, padding=pad)
+        self.conv = PositiveConv1d(1, 1, kernel_size=kernel_size, padding=pad,
+                                   min=conv_min, max=conv_max)


     def forward(self, x: Tensor,
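With padding = kernel_size // 2 and an odd kernel, the smoothing convolution preserves the time length, keeping the smoothed statistics frame-aligned with the input. A small check, with nn.Conv1d standing in for PositiveConv1d (bias=False, as the comment above notes):

    import torch
    import torch.nn as nn

    kernel_size = 15                 # assumed odd; the diff doesn't show the default
    pad = kernel_size // 2
    conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=pad, bias=False)
    x = torch.randn(4, 1, 100)       # (N, 1, T)
    assert conv(x).shape == x.shape  # output length equals T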
@@ -598,7 +597,7 @@ class ConvNorm1d(torch.nn.Module):

         # gradients to allow the parameter to get back into the allowed
         # region if it happens to exit it.
-        eps = eps.clamp(min=self.eps_min, max=self.eps_max)
+        eps = torch.clamp(eps, min=self.eps_min, max=self.eps_max)

         # sqnorms: (N, 1, T)
         sqnorms = (
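For tensors, the method and functional clamp forms are equivalent; the commit appears to switch to the functional spelling for consistency (it does the same for the counts clamp below):

    import torch

    eps = torch.tensor([-5.0, 0.0, 5.0], requires_grad=True)
    assert torch.equal(eps.clamp(min=-3.0, max=3.0),
                       torch.clamp(eps, min=-3.0, max=3.0))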
@@ -611,9 +610,9 @@ class ConvNorm1d(torch.nn.Module):
         sqnorms = sqnorms * counts
         sqnorms = self.conv(sqnorms)
         # the clamping is to avoid division by zero for padding frames.
-        counts = self.conv(counts).clamp(min=0.01)
+        counts = torch.clamp(self.conv(counts), min=0.01)
         # scales: (N, 1, T)
-        scales = (sqnorms / counts + eps.exp()) ** -0.5
+        scales = (sqnorms / counts + eps.exp()) ** -0.5 #
         return x * scales
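Taken together, ConvNorm1d appears to compute an RMS-style normalization whose squared-norm statistic is smoothed over neighboring frames. A hedged, self-contained sketch, assuming sqnorms is the per-frame mean of squared activations and counts is an all-ones mask (i.e. no padding); the real module uses PositiveConv1d and takes mask information in forward():

    import torch
    import torch.nn.functional as F

    def conv_norm_1d_sketch(x, weight, log_eps, eps_min=-3.0, eps_max=3.0):
        # x: (N, C, T); weight: (1, 1, K) positive smoothing kernel.
        N, C, T = x.shape
        pad = weight.shape[-1] // 2
        eps = torch.clamp(log_eps, min=eps_min, max=eps_max)
        counts = torch.ones(N, 1, T)
        sqnorms = (x ** 2).mean(dim=1, keepdim=True) * counts    # (N, 1, T)
        sqnorms = F.conv1d(sqnorms, weight, padding=pad)
        # the clamping avoids division by zero for padding frames
        counts = torch.clamp(F.conv1d(counts, weight, padding=pad), min=0.01)
        scales = (sqnorms / counts + eps.exp()) ** -0.5          # (N, 1, T)
        return x * scales

    y = conv_norm_1d_sketch(torch.randn(2, 384, 50),
                            torch.full((1, 1, 15), 1.0 / 15),
                            torch.tensor(0.25).log())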
@@ -1824,7 +1824,7 @@ class Conv2dSubsampling(nn.Module):

         self.out = nn.Linear(out_height * layer3_channels, out_channels)

-        self.out_norm = BasicNorm(out_channels, channel_dim=-1)
+        self.out_norm = ConvNorm1d(out_channels)
         self.dropout = Dropout2(dropout)
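For reference, BasicNorm (the norm being replaced) normalizes each frame from its own statistics alone; roughly:

    import torch

    x = torch.randn(8, 50, 384)   # (batch, time, channels)
    log_eps = torch.tensor(0.25).log()
    # roughly what BasicNorm computes with channel_dim=-1:
    scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + log_eps.exp()) ** -0.5
    x = x * scales
    # ConvNorm1d instead smooths the squared-norm statistic across
    # neighboring frames before computing the scale.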
@@ -1859,9 +1859,11 @@ class Conv2dSubsampling(nn.Module):
                              min=float(self.scale_min),
                              max=float(self.scale_max))

-        x = self.out(x)
-        x = self.out_norm(x)
-        # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
+        x = self.out(x)
+        x = x.transpose(1, 2)  # (batch, channels, time)
+        x = self.out_norm(x)
+        x = x.transpose(1, 2)  # (batch, time=((T-1)//2 - 1))//2, channels)
+
         x = self.dropout(x)
         return x
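The new transposes are needed because ConvNorm1d convolves along the trailing axis and therefore expects (batch, channels, time), while the rest of Conv2dSubsampling works in (batch, time, channels). A shape-only sketch:

    import torch

    x = torch.randn(8, 50, 384)  # (batch, time, channels) after self.out
    x = x.transpose(1, 2)        # (batch, channels, time) for ConvNorm1d
    # x = self.out_norm(x)       # normalizes, smoothing stats along time
    x = x.transpose(1, 2)        # back to (batch, time, channels)
    assert x.shape == (8, 50, 384)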