Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Revert ConvNorm1d back to BasicNorm in Conv2dSubsampling and ZipformerEncoderLayer
parent 0995970f29
commit 678be7a2eb
@@ -451,7 +451,7 @@ class ZipformerEncoderLayer(nn.Module):
 
         self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
 
-        self.norm_final = ConvNorm1d(embed_dim)
+        self.norm_final = BasicNorm(embed_dim)
 
         self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
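For context, BasicNorm is the normalization layer defined in icefall's scaling.py: roughly, it rescales activations by the root-mean-square over the channel dimension using a learned epsilon, with no per-channel affine parameters. The lines below are a simplified sketch of that idea rather than the exact class from the repository; the name BasicNormSketch and the default eps value are illustrative assumptions.

import torch
import torch.nn as nn


class BasicNormSketch(nn.Module):
    """Simplified sketch of a BasicNorm-style layer: RMS-normalize over the
    channel dimension with a learned epsilon and no per-channel affine."""

    def __init__(self, num_channels: int, channel_dim: int = -1, eps: float = 0.25):
        super().__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        # eps is stored in log space so it stays positive during training.
        self.eps = nn.Parameter(torch.tensor(eps).log())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        scales = (
            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + self.eps.exp()
        ) ** -0.5
        return x * scales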
@@ -571,9 +571,7 @@ class ZipformerEncoderLayer(nn.Module):
         src = src + self.feed_forward2(src)
 
         src = self.balancer(src)
-        src = src.permute(1, 2, 0)  # (batch, channels, time)
-        src = self.norm_final(src, src_key_padding_mask)
-        src = src.permute(2, 0, 1)  # (time, batch, channels)
+        src = self.norm_final(src)
 
         delta = src - src_orig
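Because a BasicNorm-style layer normalizes over the last (channel) dimension, the (seq_len, batch, embed_dim) activations in ZipformerEncoderLayer.forward can be normalized as-is, which is why the permutes to (batch, channels, time) and the src_key_padding_mask argument disappear in this hunk. A small shape check using the hypothetical BasicNormSketch above, with made-up sizes:

src = torch.randn(200, 8, 512)     # (seq_len, batch, embed_dim)
norm_final = BasicNormSketch(512)  # channel dim is already last
out = norm_final(src)              # no permute or padding mask needed
assert out.shape == src.shape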
@@ -1847,7 +1845,7 @@ class Conv2dSubsampling(nn.Module):
 
         self.out = nn.Linear(out_height * layer3_channels, out_channels)
 
-        self.out_norm = ConvNorm1d(out_channels)
+        self.out_norm = BasicNorm(out_channels)
         self.dropout = Dropout2(dropout)
 
@@ -1884,9 +1882,7 @@ class Conv2dSubsampling(nn.Module):
 
         # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
         x = self.out(x)
-        x = x.transpose(1, 2)  # (batch, channels, time)
         x = self.out_norm(x)
-        x = x.transpose(1, 2)  # (batch, time=((T-1)//2 - 1))//2, channels)
         x = self.dropout(x)
         return x
 
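The same reasoning applies in Conv2dSubsampling.forward: self.out produces a channels-last tensor of shape (N, T', odim), so out_norm can be applied directly and the transpose pair around it becomes unnecessary. A minimal illustration, again with the hypothetical BasicNormSketch and arbitrary sizes:

x = torch.randn(8, 49, 384)      # (batch, subsampled time, out_channels)
out_norm = BasicNormSketch(384)  # normalizes over the last dimension
y = out_norm(x)                  # no transpose to (batch, channels, time) needed
assert y.shape == x.shape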