Replace most normalizations with scales (still have norm in conv)

Daniel Povey 2022-03-10 14:43:54 +08:00
parent 059b57ad37
commit b55472bb42
2 changed files with 24 additions and 42 deletions

View File

@@ -58,6 +58,7 @@ class Conv2dSubsampling(nn.Module):
             ExpScaleRelu(odim, 1, 1, speed=20.0),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
+        self.out_norm = BasicNorm(odim)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
@@ -76,7 +77,7 @@ class Conv2dSubsampling(nn.Module):
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
         # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
-        x = x * 0.1
+        x = self.out_norm(x)
         return x
@@ -200,9 +201,11 @@ class PeLU(torch.nn.Module):
         return PeLUFunction.apply(x, self.cutoff, self.alpha)

 class ExpScale(torch.nn.Module):
-    def __init__(self, *shape, speed: float = 1.0):
+    def __init__(self, *shape, speed: float = 1.0, initial_scale: float = 1.0):
         super(ExpScale, self).__init__()
-        self.scale = nn.Parameter(torch.zeros(*shape))
+        scale = torch.tensor(initial_scale)
+        scale = scale.log() / speed
+        self.scale = nn.Parameter(scale.detach())
         self.speed = speed

     def forward(self, x: Tensor) -> Tensor:
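
Note on the ExpScale change above: the parameter is now stored as log(initial_scale) / speed, so if the forward pass multiplies by (self.scale * self.speed).exp() — the forward itself is not shown in this diff, so treat that as an assumption — the module starts out as multiplication by exactly initial_scale. A minimal sketch under that assumption:

    import torch
    from torch import Tensor, nn

    class ExpScale(nn.Module):
        # Sketch of the modified class; forward() here is assumed, not shown in this diff.
        def __init__(self, *shape, speed: float = 1.0, initial_scale: float = 1.0):
            super(ExpScale, self).__init__()
            scale = torch.tensor(initial_scale)
            scale = scale.log() / speed  # exp(speed * scale) == initial_scale
            self.scale = nn.Parameter(scale.detach())
            self.speed = speed

        def forward(self, x: Tensor) -> Tensor:
            # Multiply by a learned positive scalar, initialized to initial_scale.
            return x * (self.scale * self.speed).exp()

    s = ExpScale(1, speed=10.0, initial_scale=0.2)
    x = torch.randn(3, 4)
    assert torch.allclose(s(x), 0.2 * x)  # behaves as "x * 0.2" at initialization

Under this reading, a larger speed simply makes the learned log-scale parameter move faster for a given gradient step.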

View File

@@ -19,7 +19,7 @@ import copy
 import math
 import warnings
 from typing import Optional, Tuple, Sequence
-from subsampling import PeLU, ExpScale, ExpScaleSwish, ExpScaleRelu, DerivBalancer
+from subsampling import PeLU, ExpScale, ExpScaleSwish, ExpScaleRelu, DerivBalancer, BasicNorm
 import torch
 from torch import Tensor, nn
@@ -150,6 +150,8 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
+        self.d_model = d_model
+
         self.self_attn = RelPositionMultiheadAttention(
             d_model, nhead, dropout=0.0
         )
@@ -174,22 +176,15 @@ class ConformerEncoderLayer(nn.Module):

         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)

-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
-        self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
-        self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
+        self.scale_mha = ExpScale(1, speed=10.0, initial_scale=0.2)
+        self.scale_conv = ExpScale(1, speed=10.0, initial_scale=0.5)
+        self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5)
+        self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5)

-        self.ff_scale = 0.5
-        self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = BasicNorm(d_model)

         self.dropout = nn.Dropout(dropout)
-        self.normalize_before = normalize_before

     def forward(
         self,
@@ -217,18 +212,15 @@ class ConformerEncoderLayer(nn.Module):

         # macaron style feed forward module
         residual = src
-        if self.normalize_before:
-            src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
-        if not self.normalize_before:
-            src = self.norm_ff_macaron(src)
+        src = src + self.dropout(self.feed_forward_macaron(
+            self.scale_ff_macaron(src)))

         # multi-headed self-attention module
         residual = src
-        if self.normalize_before:
-            src = self.norm_mha(src)
+        src = self.scale_mha(src)
         src_att = self.self_attn(
             src,
             src,
@@ -238,27 +230,14 @@ class ConformerEncoderLayer(nn.Module):
             key_padding_mask=src_key_padding_mask,
         )[0]
         src = residual + self.dropout(src_att)
-        if not self.normalize_before:
-            src = self.norm_mha(src)

         # convolution module
-        residual = src
-        if self.normalize_before:
-            src = self.norm_conv(src)
-        src = residual + self.dropout(self.conv_module(src))
-        if not self.normalize_before:
-            src = self.norm_conv(src)
+        src = residual + self.dropout(self.conv_module(self.scale_conv(src)))

         # feed forward module
-        residual = src
-        if self.normalize_before:
-            src = self.norm_ff(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
-        if not self.normalize_before:
-            src = self.norm_ff(src)
+        src = src + self.dropout(self.feed_forward(self.scale_ff(src)))

-        if self.normalize_before:
-            src = self.norm_final(src)
+        src = self.norm_final(src)

         return src
@@ -288,7 +267,7 @@ class ConformerEncoder(nn.Module):
         self.aux_layers = set(aux_layers + [num_layers - 1])
         assert num_layers - 1 not in aux_layers
         self.num_layers = num_layers
-        num_channels = encoder_layer.norm_final.weight.numel()
+        num_channels = encoder_layer.d_model
         self.combiner = RandomCombine(num_inputs=len(self.aux_layers),
                                       num_channels=num_channels,
                                       final_weight=0.5,
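
BasicNorm is imported from subsampling but its definition is not part of this diff. A rough sketch of what it is assumed to look like — RMS-style normalization over the channel dimension with a learned, log-parameterized epsilon and no affine weight or bias — which would also be consistent with the last hunk reading num_channels from encoder_layer.d_model rather than from norm_final.weight.numel():

    import torch
    from torch import Tensor, nn

    class BasicNorm(nn.Module):
        # Hedged sketch only; the real class lives in subsampling.py and may differ.
        # Normalizes each frame by its root-mean-square over the channel dimension,
        # with a learned epsilon stored in log space and no affine weight/bias.
        def __init__(self, num_channels: int, eps: float = 0.25) -> None:
            super().__init__()
            self.num_channels = num_channels
            self.eps = nn.Parameter(torch.tensor(float(eps)).log().detach())

        def forward(self, x: Tensor) -> Tensor:
            # x: (..., num_channels); scale so the per-frame mean square is roughly 1.
            scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps.exp()) ** -0.5
            return x * scales

Under this assumption, replacing the fixed "x = x * 0.1" in Conv2dSubsampling.forward with self.out_norm(x) keeps the subsampler output at a controlled scale without reintroducing a LayerNorm-style learned weight.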