Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)

Commit b55472bb42 (parent 059b57ad37)

    Replace most normalizations with scales (still have norm in conv)
subsampling.py:

@@ -58,6 +58,7 @@ class Conv2dSubsampling(nn.Module):
             ExpScaleRelu(odim, 1, 1, speed=20.0),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
+        self.out_norm = BasicNorm(odim)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
@@ -76,7 +77,7 @@ class Conv2dSubsampling(nn.Module):
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
         # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
-        x = x * 0.1
+        x = self.out_norm(x)
         return x
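BasicNorm is imported from subsampling but its definition is not part of this diff. For orientation, here is a minimal sketch of a norm in this style, assuming it rescales each frame by its RMS over the channel dimension with a learned floor; the real implementation in subsampling.py may differ in details.

import torch
from torch import nn


class BasicNorm(nn.Module):
    """Sketch only: normalize by RMS over the last dim, with a learnable
    (always-positive) floor on the mean-square.  Assumed, not taken from
    this diff."""

    def __init__(self, num_channels: int, eps: float = 0.25) -> None:
        super().__init__()
        self.num_channels = num_channels
        # Store log(eps) so the floor stays positive during training.
        self.eps = nn.Parameter(torch.tensor(eps).log().detach())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., num_channels); scale each position by 1 / sqrt(mean(x^2) + eps).
        scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps.exp()) ** -0.5
        return x * scales

In the Conv2dSubsampling change above, a layer of this kind replaces the fixed x = x * 0.1 rescaling with a data-dependent one.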
@@ -200,9 +201,11 @@ class PeLU(torch.nn.Module):
         return PeLUFunction.apply(x, self.cutoff, self.alpha)
 
 class ExpScale(torch.nn.Module):
-    def __init__(self, *shape, speed: float = 1.0):
+    def __init__(self, *shape, speed: float = 1.0, initial_scale: float = 1.0):
         super(ExpScale, self).__init__()
-        self.scale = nn.Parameter(torch.zeros(*shape))
+        scale = torch.tensor(initial_scale)
+        scale = scale.log() / speed
+        self.scale = nn.Parameter(scale.detach())
         self.speed = speed
 
     def forward(self, x: Tensor) -> Tensor:
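Only __init__ is touched here and forward is not shown. Below is a self-contained sketch of the whole module under the assumption that forward multiplies its input by exp(scale * speed); with the new parameterization the effective multiplier therefore starts exactly at initial_scale, and speed controls how quickly gradients can move it.

import torch
from torch import Tensor, nn


class ExpScale(nn.Module):
    """Sketch of the scale module after this change (forward body assumed)."""

    def __init__(self, *shape, speed: float = 1.0, initial_scale: float = 1.0) -> None:
        super().__init__()
        # Parameter is stored as log(initial_scale) / speed, so that
        # exp(scale * speed) == initial_scale at initialization.
        # NOTE: in the diff above the parameter is a scalar; *shape is accepted
        # but not used to shape it.
        scale = torch.tensor(initial_scale)
        scale = scale.log() / speed
        self.scale = nn.Parameter(scale.detach())
        self.speed = speed

    def forward(self, x: Tensor) -> Tensor:
        # Assumed behaviour: an always-positive learned multiplier.
        return x * (self.scale * self.speed).exp()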
conformer.py:

@@ -19,7 +19,7 @@ import copy
 import math
 import warnings
 from typing import Optional, Tuple, Sequence
-from subsampling import PeLU, ExpScale, ExpScaleSwish, ExpScaleRelu, DerivBalancer
+from subsampling import PeLU, ExpScale, ExpScaleSwish, ExpScaleRelu, DerivBalancer, BasicNorm
 
 import torch
 from torch import Tensor, nn
@@ -150,6 +150,8 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
+        self.d_model = d_model
+
         self.self_attn = RelPositionMultiheadAttention(
             d_model, nhead, dropout=0.0
         )
@@ -174,22 +176,15 @@ class ConformerEncoderLayer(nn.Module):
 
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
 
-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
-        self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
-        self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
+        self.scale_mha = ExpScale(1, speed=10.0, initial_scale=0.2)
+        self.scale_conv = ExpScale(1, speed=10.0, initial_scale=0.5)
+        self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5)
+        self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5)
 
-        self.ff_scale = 0.5
-
-        self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = BasicNorm(d_model)
 
         self.dropout = nn.Dropout(dropout)
 
         self.normalize_before = normalize_before
 
     def forward(
         self,
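At initialization each of these ExpScale modules therefore acts as multiplication by its initial_scale (0.2 on the attention branch, 0.5 elsewhere); the 0.5 values echo the removed ff_scale = 0.5 constant, although the scale is now applied to the module input rather than to its output, and it is learnable. A quick check, valid under the assumed ExpScale forward sketched earlier:

import torch
from subsampling import ExpScale  # or the sketch above

scale_mha = ExpScale(1, speed=10.0, initial_scale=0.2)
x = torch.randn(4, 6, 256)
# At init, exp(scale * speed) == exp(log(0.2)) == 0.2, so this is a plain 0.2 multiplier.
assert torch.allclose(scale_mha(x), 0.2 * x, atol=1e-5)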
@@ -217,18 +212,15 @@ class ConformerEncoderLayer(nn.Module):
 
         # macaron style feed forward module
-        residual = src
-        if self.normalize_before:
-            src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
-        if not self.normalize_before:
-            src = self.norm_ff_macaron(src)
+        src = src + self.dropout(self.feed_forward_macaron(
+            self.scale_ff_macaron(src)))
 
         # multi-headed self-attention module
         residual = src
-        if self.normalize_before:
-            src = self.norm_mha(src)
+        src = self.scale_mha(src)
         src_att = self.self_attn(
             src,
             src,
@@ -238,26 +230,13 @@ class ConformerEncoderLayer(nn.Module):
             key_padding_mask=src_key_padding_mask,
         )[0]
         src = residual + self.dropout(src_att)
-        if not self.normalize_before:
-            src = self.norm_mha(src)
 
         # convolution module
         residual = src
-        if self.normalize_before:
-            src = self.norm_conv(src)
-        src = residual + self.dropout(self.conv_module(src))
-        if not self.normalize_before:
-            src = self.norm_conv(src)
+        src = residual + self.dropout(self.conv_module(self.scale_conv(src)))
 
         # feed forward module
-        residual = src
-        if self.normalize_before:
-            src = self.norm_ff(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
-        if not self.normalize_before:
-            src = self.norm_ff(src)
+        src = src + self.dropout(self.feed_forward(self.scale_ff(src)))
 
         if self.normalize_before:
             src = self.norm_final(src)
 
         return src
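Read together, the two forward() hunks leave a much shorter layer body. The reconstruction below stitches the added lines and the unchanged context into one piece; the forward signature and the middle arguments of the self_attn call are not visible in this diff and are assumed, so treat it as a sketch rather than the exact file contents.

    def forward(self, src, pos_emb, src_mask=None, src_key_padding_mask=None):
        # signature assumed from the surrounding (unchanged) code

        # macaron style feed forward module: learned input scale, no pre-LayerNorm
        src = src + self.dropout(self.feed_forward_macaron(
            self.scale_ff_macaron(src)))

        # multi-headed self-attention module
        residual = src
        src = self.scale_mha(src)
        src_att = self.self_attn(
            src,
            src,
            src,                # arguments between here and key_padding_mask assumed
            pos_emb=pos_emb,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )[0]
        src = residual + self.dropout(src_att)

        # convolution module
        residual = src
        src = residual + self.dropout(self.conv_module(self.scale_conv(src)))

        # feed forward module
        src = src + self.dropout(self.feed_forward(self.scale_ff(src)))

        if self.normalize_before:
            src = self.norm_final(src)

        return src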
@@ -288,7 +267,7 @@ class ConformerEncoder(nn.Module):
         self.aux_layers = set(aux_layers + [num_layers - 1])
         assert num_layers - 1 not in aux_layers
         self.num_layers = num_layers
-        num_channels = encoder_layer.norm_final.weight.numel()
+        num_channels = encoder_layer.d_model
         self.combiner = RandomCombine(num_inputs=len(self.aux_layers),
                                       num_channels=num_channels,
                                       final_weight=0.5,
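Previously the channel count fed to RandomCombine was read off norm_final's weight vector, which only works while norm_final is an nn.LayerNorm. With norm_final now a BasicNorm (which, at least in the sketch earlier, carries no weight parameter), the count comes from the d_model attribute added to ConformerEncoderLayer.__init__ above. A hypothetical illustration, with constructor arguments assumed:

layer = ConformerEncoderLayer(d_model=256, nhead=4)  # arguments assumed for illustration
# Old: relied on norm_final being an nn.LayerNorm exposing a `weight` tensor.
# num_channels = layer.norm_final.weight.numel()
# New: the layer records its dimension explicitly, independent of the norm type.
num_channels = layer.d_model  # -> 256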