Introduce a scale dependent on the masking value
parent 1be455438a
commit 93dff29243
@@ -18,17 +18,18 @@
 import copy
 import math
 import warnings
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 import logging
 import torch
 import random
 from encoder_interface import EncoderInterface
 from scaling import (
     ActivationBalancer,
     BasicNorm,
     DoubleSwish,
     ScaledConv1d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+    LearnedScale,
 )
 from torch import Tensor, nn

@@ -171,6 +172,7 @@ class ConformerEncoderLayer(nn.Module):
         self.self_attn = RelPositionMultiheadAttention(
             d_model, nhead, dropout=dropout,
         )
+        self.self_attn_scale = LearnedScale()

         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -181,6 +183,7 @@ class ConformerEncoderLayer(nn.Module):
             ScaledLinear(dim_feedforward, d_model,
                          initial_scale=0.1),
         )
+        self.feed_forward_scale = LearnedScale()

         self.feed_forward_macaron = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -191,11 +194,14 @@ class ConformerEncoderLayer(nn.Module):
             ScaledLinear(dim_feedforward, d_model,
                          initial_scale=0.1),
         )
+        self.feed_forward_macaron_scale = LearnedScale()

         self.conv_module = ConvolutionModule(d_model,
                                              cnn_module_kernel)
+        self.conv_scale = LearnedScale()

         self.norm_final = BasicNorm(d_model)
+        self.final_scale = LearnedScale()

         # try to ensure the output is close to zero-mean (or at least, zero-median).
         self.balancer = ActivationBalancer(
@ -209,11 +215,11 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
src: Tensor,
|
src: Tensor,
|
||||||
|
feature_mask: Union[Tensor, float],
|
||||||
pos_emb: Tensor,
|
pos_emb: Tensor,
|
||||||
attn_scores_in: Optional[Tensor] = None,
|
attn_scores_in: Optional[Tensor] = None,
|
||||||
src_mask: Optional[Tensor] = None,
|
src_mask: Optional[Tensor] = None,
|
||||||
src_key_padding_mask: Optional[Tensor] = None,
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
feature_mask: Optional[Tensor] = None,
|
|
||||||
warmup: float = 1.0,
|
warmup: float = 1.0,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -233,10 +239,10 @@ class ConformerEncoderLayer(nn.Module):

         Shape:
             src: (S, N, E).
+            feature_mask: float, or (S, N, 1)
             pos_emb: (N, 2*S-1, E)
             src_mask: (S, S).
             src_key_padding_mask: (N, S).
-            feature_mask: (S, N, E)
             S is the source sequence length, N is the batch size, E is the feature number
         """
         src_orig = src
@@ -254,7 +260,8 @@ class ConformerEncoderLayer(nn.Module):
             alpha = 1.0

         # macaron style feed forward module
-        src = src + self.feed_forward_macaron(src)
+        src = src + self.feed_forward_macaron_scale(self.feed_forward_macaron(src),
+                                                    feature_mask)

         # multi-headed self-attention module
         src_att, _, attn_scores_out = self.self_attn(
@@ -264,25 +271,24 @@ class ConformerEncoderLayer(nn.Module):
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )
-        src = src + src_att
+        src = src + self.self_attn_scale(src_att, feature_mask)

         # convolution module
-        src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
+        src = src + self.conv_scale(self.conv_module(src, src_key_padding_mask=src_key_padding_mask),
+                                    feature_mask)

-

         # feed forward module
-        src = src + self.feed_forward(src)
+        src = src + self.feed_forward_scale(self.feed_forward(src),
+                                            feature_mask)

+        src = self.final_scale(src, feature_mask)

         src = self.norm_final(self.balancer(src))

         if alpha != 1.0:
             src = alpha * src + (1 - alpha) * src_orig

-        if feature_mask is not None:
-            src = src * feature_mask
-
         return src, attn_scores_out

@@ -359,23 +365,28 @@ class ConformerEncoder(nn.Module):
             feature_mask_dropout_prob = 0.15
             feature_unmasked_dim = 256  # hardcode dim for now, 1st 256 are non-masked.

-            feature_mask = torch.ones_like(src)  # S, N, E
-            # is_masked_frame is 0 with probability `feature_mask_dropout_prob`
-            is_masked_frame = (torch.rand_like(src[...,:1]) > feature_mask_dropout_prob).to(src.dtype)
-            feature_mask[..., feature_unmasked_dim:] *= is_masked_frame
+            full_feature_mask = torch.ones_like(src)  # S, N, E
+            # feature_mask is 0 with probability `feature_mask_dropout_prob`
+            # feature_mask shape: (S, N, 1)
+            feature_mask = (torch.rand_like(src[...,:1]) > feature_mask_dropout_prob).to(src.dtype)
+            full_feature_mask[..., feature_unmasked_dim:] *= feature_mask
         else:
-            feature_mask = None
+            feature_mask = 1.0
+            full_feature_mask = 1.0
+
+        src = src * full_feature_mask

         for i, mod in enumerate(self.layers):
             output, attn_scores = mod(
                 output,
+                feature_mask,
                 pos_emb,
                 attn_scores,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
-                feature_mask=feature_mask,
                 warmup=warmup,
             )
+            output = output * full_feature_mask
             if i in self.aux_layers:
                 outputs.append(output)

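For reference, a minimal standalone sketch of the masking scheme this hunk sets up, with names and shapes taken from the diff (the surrounding warmup and aux-layer logic is omitted, and the concrete sizes are made up for illustration):

import torch

S, N, E = 100, 8, 384          # sequence length, batch size, feature dim
src = torch.randn(S, N, E)

feature_mask_dropout_prob = 0.15
feature_unmasked_dim = 256     # the first 256 dims are never masked

# Per-frame mask of shape (S, N, 1): 1.0 with probability 0.85, else 0.0.
feature_mask = (torch.rand_like(src[..., :1]) > feature_mask_dropout_prob).to(src.dtype)

# Full mask of shape (S, N, E): only dims >= feature_unmasked_dim are affected.
full_feature_mask = torch.ones_like(src)
full_feature_mask[..., feature_unmasked_dim:] *= feature_mask

# The encoder multiplies by the full mask before the layer stack and again after
# every layer; the per-frame mask is what each layer's LearnedScale modules see.
src = src * full_feature_mask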
@@ -326,6 +326,30 @@ def ScaledConv1d(*args,
     return ans


+class LearnedScale(torch.nn.Module):
+    """
+    Module that learns a scale dependent on some kind of mask that is typically going to be 0 or 1
+    in training. The scale will be 1.0 if the mask is 1.0, but may be a different (learned) value
+    if the mask value is not 1.0.
+
+    The idea is that if we have some kind of feature mask that would always be 1.0 in
+    test mode but might sometimes be 0.0 in training mode, we might want to multiply
+    the remaining features by a value dependent on this mask.
+    """
+    def __init__(self):
+        super(LearnedScale, self).__init__()
+        self.alpha = nn.Parameter(torch.tensor(0.0))
+
+    def forward(self,
+                x: Tensor,
+                mask: Tensor):
+        """
+        Mask should either be a number (probably 1.0) or a tensor that broadcasts with x.
+        """
+        if self.training and isinstance(mask, float) and mask == 1.0:
+            return x
+        return x * (1.0 + self.alpha * (1.0 - mask))
+
+
 class ActivationBalancer(torch.nn.Module):
     """
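As a quick sanity check of the new module's behaviour, a self-contained sketch that condenses the LearnedScale class added above (not part of the commit):

import torch
from torch import Tensor, nn

class LearnedScale(nn.Module):
    # Condensed copy of the class added in the hunk above, for illustration only.
    def __init__(self):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(0.0))

    def forward(self, x: Tensor, mask):
        # With mask == 1.0 (the test-time value) the module is the identity.
        if self.training and isinstance(mask, float) and mask == 1.0:
            return x
        return x * (1.0 + self.alpha * (1.0 - mask))

scale = LearnedScale()
x = torch.randn(10, 4, 256)

# mask == 1.0: output equals the input, regardless of alpha.
assert torch.equal(scale(x, 1.0), x)

# mask == 0.0 (a masked frame in training): output is x * (1 + alpha), which
# starts out as the identity because alpha is initialised to 0.0.
with torch.no_grad():
    scale.alpha.fill_(0.5)
assert torch.allclose(scale(x, 0.0), x * 1.5)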