Replace most normalizations with scales (still have norm in conv)

Daniel Povey 2022-03-10 14:43:54 +08:00
parent 059b57ad37
commit b55472bb42
2 changed files with 24 additions and 42 deletions

View File

@@ -58,6 +58,7 @@ class Conv2dSubsampling(nn.Module):
             ExpScaleRelu(odim, 1, 1, speed=20.0),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
+        self.out_norm = BasicNorm(odim)
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
@@ -76,7 +77,7 @@ class Conv2dSubsampling(nn.Module):
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
         # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
-        x = x * 0.1
+        x = self.out_norm(x)
         return x
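A note on this hunk: BasicNorm is defined elsewhere in the subsampling module (it is imported by the second file below), so what follows is only a sketch of a BasicNorm-style layer, assuming it divides each frame by the root-mean-square of its channels plus a learned epsilon and has no per-channel weight or bias. Under that assumption the fixed x = x * 0.1 scaling becomes redundant, since the norm itself controls the output magnitude.

import torch
from torch import Tensor, nn


class BasicNormSketch(nn.Module):
    """Hypothetical stand-in for BasicNorm (its real definition is not part
    of this diff): divide by sqrt(mean(x**2) + eps) over the channel dim,
    keeping eps in log space so it stays positive while being trained."""

    def __init__(self, num_channels: int, eps: float = 0.25) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.eps = nn.Parameter(torch.tensor(eps).log().detach())

    def forward(self, x: Tensor) -> Tensor:
        # x: (N, T, num_channels); scale each frame by 1 / RMS(frame).
        scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps.exp()) ** -0.5
        return x * scales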
@@ -200,9 +201,11 @@ class PeLU(torch.nn.Module):
         return PeLUFunction.apply(x, self.cutoff, self.alpha)
 class ExpScale(torch.nn.Module):
-    def __init__(self, *shape, speed: float = 1.0):
+    def __init__(self, *shape, speed: float = 1.0, initial_scale: float = 1.0):
         super(ExpScale, self).__init__()
-        self.scale = nn.Parameter(torch.zeros(*shape))
+        scale = torch.tensor(initial_scale)
+        scale = scale.log() / speed
+        self.scale = nn.Parameter(scale.detach())
         self.speed = speed
     def forward(self, x: Tensor) -> Tensor:
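A note on the new initial_scale argument: the log() / speed initialization makes sense if ExpScale's forward pass (cut off above) multiplies its input by (scale * speed).exp(); under that assumption the module scales by exactly initial_scale at initialization, and speed mainly rescales the gradient that the parameter receives. A small self-contained sketch under that assumption (broadcasting over *shape is also assumed, since the hunk only shows the scalar path):

import torch
from torch import Tensor, nn


class ExpScaleSketch(nn.Module):
    """Sketch of ExpScale as used in this commit; the forward body shown
    here, x * (scale * speed).exp(), is an assumption."""

    def __init__(self, *shape, speed: float = 1.0, initial_scale: float = 1.0):
        super().__init__()
        scale = torch.tensor(initial_scale).log() / speed
        self.scale = nn.Parameter(scale.detach() * torch.ones(*shape))
        self.speed = speed

    def forward(self, x: Tensor) -> Tensor:
        return x * (self.scale * self.speed).exp()


x = torch.randn(2, 5, 4)
s = ExpScaleSketch(1, speed=10.0, initial_scale=0.5)
# exp(log(0.5) / 10 * 10) == 0.5, so the layer starts as multiplication by 0.5.
assert torch.allclose(s(x), 0.5 * x)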

View File

@@ -19,7 +19,7 @@ import copy
 import math
 import warnings
 from typing import Optional, Tuple, Sequence
-from subsampling import PeLU, ExpScale, ExpScaleSwish, ExpScaleRelu, DerivBalancer
+from subsampling import PeLU, ExpScale, ExpScaleSwish, ExpScaleRelu, DerivBalancer, BasicNorm
 import torch
 from torch import Tensor, nn
@@ -150,6 +150,8 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
+        self.d_model = d_model
         self.self_attn = RelPositionMultiheadAttention(
             d_model, nhead, dropout=0.0
         )
@@ -174,22 +176,15 @@ class ConformerEncoderLayer(nn.Module):
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
-        self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
-        self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
+        self.scale_mha = ExpScale(1, speed=10.0, initial_scale=0.2)
+        self.scale_conv = ExpScale(1, speed=10.0, initial_scale=0.5)
+        self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5)
+        self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5)
-        self.ff_scale = 0.5
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = BasicNorm(d_model)
         self.dropout = nn.Dropout(dropout)
         self.normalize_before = normalize_before
     def forward(
         self,
@@ -217,18 +212,15 @@ class ConformerEncoderLayer(nn.Module):
         # macaron style feed forward module
         residual = src
-        if self.normalize_before:
-            src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
-        if not self.normalize_before:
-            src = self.norm_ff_macaron(src)
+        src = src + self.dropout(self.feed_forward_macaron(
+            self.scale_ff_macaron(src)))
         # multi-headed self-attention module
         residual = src
-        if self.normalize_before:
-            src = self.norm_mha(src)
+        src = self.scale_mha(src)
         src_att = self.self_attn(
             src,
             src,
@@ -238,26 +230,13 @@ class ConformerEncoderLayer(nn.Module):
             key_padding_mask=src_key_padding_mask,
         )[0]
         src = residual + self.dropout(src_att)
-        if not self.normalize_before:
-            src = self.norm_mha(src)
         # convolution module
         residual = src
-        if self.normalize_before:
-            src = self.norm_conv(src)
-        src = residual + self.dropout(self.conv_module(src))
-        if not self.normalize_before:
-            src = self.norm_conv(src)
+        src = residual + self.dropout(self.conv_module(self.scale_conv(src)))
         # feed forward module
         residual = src
-        if self.normalize_before:
-            src = self.norm_ff(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
-        if not self.normalize_before:
-            src = self.norm_ff(src)
+        src = src + self.dropout(self.feed_forward(self.scale_ff(src)))
         if self.normalize_before:
             src = self.norm_final(src)
         return src
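Taken together, the hunks above replace the pre/post-LayerNorm residual pattern with a learned scale applied to each module's input; the convolution module keeps its own internal norm (per the commit message), and the final norm, now a BasicNorm, stays behind the normalize_before flag. A minimal sketch of one residual branch before and after, with scale standing in for the ExpScale modules created in __init__:

from torch import Tensor


def old_branch(src: Tensor, norm, module, dropout,
               ff_scale: float = 1.0, normalize_before: bool = True) -> Tensor:
    # Removed pattern: LayerNorm before (or after) the module, plus a fixed
    # 0.5 weight on the feed-forward branches.
    residual = src
    if normalize_before:
        src = norm(src)
    src = residual + ff_scale * dropout(module(src))
    if not normalize_before:
        src = norm(src)
    return src


def new_branch(src: Tensor, scale, module, dropout) -> Tensor:
    # Added pattern: a learned ExpScale replaces the norm and the fixed
    # feed-forward weight; the residual connection itself is unchanged.
    return src + dropout(module(scale(src)))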
@@ -288,7 +267,7 @@ class ConformerEncoder(nn.Module):
         self.aux_layers = set(aux_layers + [num_layers - 1])
         assert num_layers - 1 not in aux_layers
         self.num_layers = num_layers
-        num_channels = encoder_layer.norm_final.weight.numel()
+        num_channels = encoder_layer.d_model
         self.combiner = RandomCombine(num_inputs=len(self.aux_layers),
                                       num_channels=num_channels,
                                       final_weight=0.5,
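Why this last hunk is needed: nn.LayerNorm carries a weight vector with d_model elements, so the old code could recover the encoder width from norm_final.weight.numel(); once norm_final is a BasicNorm, which evidently has no such per-channel weight, the width has to come from the new self.d_model attribute set earlier in this file. A quick illustration of the old trick:

import torch.nn as nn

# What the removed line relied on: LayerNorm's elementwise-affine weight
# has exactly d_model elements.
norm_final = nn.LayerNorm(256)
assert norm_final.weight.numel() == 256

# BasicNorm exposes no equivalent weight, hence num_channels is now taken
# from encoder_layer.d_model instead.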