Introduce a scale dependent on the masking value

Daniel Povey 2022-10-03 14:34:37 +08:00
parent 1be455438a
commit 93dff29243
2 changed files with 54 additions and 19 deletions
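In brief: the new LearnedScale module (added to scaling.py below) computes x * (1.0 + alpha * (1.0 - mask)), i.e. the identity when the mask is 1.0 and a learned gain otherwise. A minimal numeric sketch of that formula (the alpha value here is illustrative, not from this commit; the parameter is initialized to 0.0 and learned):

alpha = 0.5                          # illustrative value; learned in practice
for mask in (1.0, 0.0):
    scale = 1.0 + alpha * (1.0 - mask)
    print(mask, scale)               # mask=1.0 -> scale=1.0 (identity); mask=0.0 -> scale=1.5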

View File

@@ -18,17 +18,18 @@
import copy
import math
import warnings
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union
import logging
import torch
import random
from encoder_interface import EncoderInterface
from scaling import (
ActivationBalancer,
BasicNorm,
DoubleSwish,
ScaledConv1d,
ScaledLinear, # not as in other dirs; it just scales down initial parameter values.
LearnedScale,
)
from torch import Tensor, nn
@@ -171,6 +172,7 @@ class ConformerEncoderLayer(nn.Module):
self.self_attn = RelPositionMultiheadAttention(
d_model, nhead, dropout=dropout,
)
self.self_attn_scale = LearnedScale()
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
@@ -181,6 +183,7 @@
ScaledLinear(dim_feedforward, d_model,
initial_scale=0.1),
)
self.feed_forward_scale = LearnedScale()
self.feed_forward_macaron = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
@@ -191,11 +194,14 @@
ScaledLinear(dim_feedforward, d_model,
initial_scale=0.1),
)
self.feed_forward_macaron_scale = LearnedScale()
self.conv_module = ConvolutionModule(d_model,
cnn_module_kernel)
self.conv_scale = LearnedScale()
self.norm_final = BasicNorm(d_model)
self.final_scale = LearnedScale()
# try to ensure the output is close to zero-mean (or at least, zero-median).
self.balancer = ActivationBalancer(
@@ -209,11 +215,11 @@
def forward(
self,
src: Tensor,
feature_mask: Union[Tensor, float],
pos_emb: Tensor,
attn_scores_in: Optional[Tensor] = None,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
feature_mask: Optional[Tensor] = None,
warmup: float = 1.0,
) -> Tensor:
"""
@@ -233,10 +239,10 @@
Shape:
src: (S, N, E).
feature_mask: float, or (S, N, 1)
pos_emb: (N, 2*S-1, E)
src_mask: (S, S).
src_key_padding_mask: (N, S).
feature_mask: (S, N, E)
S is the source sequence length, N is the batch size, E is the feature number
"""
src_orig = src
@@ -254,7 +260,8 @@
alpha = 1.0
# macaron style feed forward module
src = src + self.feed_forward_macaron(src)
src = src + self.feed_forward_macaron_scale(self.feed_forward_macaron(src),
feature_mask)
# multi-headed self-attention module
src_att, _, attn_scores_out = self.self_attn(
@@ -264,25 +271,24 @@
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)
src = src + src_att
src = src + self.self_attn_scale(src_att, feature_mask)
# convolution module
src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
src = src + self.conv_scale(self.conv_module(src, src_key_padding_mask=src_key_padding_mask),
feature_mask)
# feed forward module
src = src + self.feed_forward(src)
src = src + self.feed_forward_scale(self.feed_forward(src),
feature_mask)
src = self.final_scale(src, feature_mask)
src = self.norm_final(self.balancer(src))
if alpha != 1.0:
src = alpha * src + (1 - alpha) * src_orig
if feature_mask is not None:
src = src * feature_mask
return src, attn_scores_out
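The pattern in the hunk above: each residual branch's output now passes through its own LearnedScale together with feature_mask before the residual add, replacing the old blanket masking of the layer output. A self-contained sketch of one such branch (ToyBranch is a placeholder, not from the commit; assumes scaling.py with the LearnedScale diffed below is importable):

from typing import Union
from torch import Tensor, nn
from scaling import LearnedScale  # the new module diffed below

class ToyBranch(nn.Module):
    """One residual branch paired with its LearnedScale, mirroring e.g.
    self.feed_forward / self.feed_forward_scale in the hunk above."""
    def __init__(self, d_model: int):
        super().__init__()
        self.branch = nn.Linear(d_model, d_model)
        self.branch_scale = LearnedScale()

    def forward(self, src: Tensor, feature_mask: Union[Tensor, float]) -> Tensor:
        # scale the branch output by 1 + alpha * (1 - feature_mask), then add it back
        return src + self.branch_scale(self.branch(src), feature_mask)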
@@ -359,23 +365,28 @@ class ConformerEncoder(nn.Module):
feature_mask_dropout_prob = 0.15
feature_unmasked_dim = 256 # hardcoded for now; the first 256 dims are never masked.
feature_mask = torch.ones_like(src) # S, N, E
# is_masked_frame is 0 with probability `feature_mask_dropout_prob`
is_masked_frame = (torch.rand_like(src[...,:1]) > feature_mask_dropout_prob).to(src.dtype)
feature_mask[..., feature_unmasked_dim:] *= is_masked_frame
full_feature_mask = torch.ones_like(src) # S, N, E
# feature_mask is 0 with probability `feature_mask_dropout_prob`
# feature_mask shape: (S, N, 1)
feature_mask = (torch.rand_like(src[...,:1]) > feature_mask_dropout_prob).to(src.dtype)
full_feature_mask[..., feature_unmasked_dim:] *= feature_mask
else:
feature_mask = None
feature_mask = 1.0
full_feature_mask = 1.0
src = src * full_feature_mask
for i, mod in enumerate(self.layers):
output, attn_scores = mod(
output,
feature_mask,
pos_emb,
attn_scores,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
feature_mask=feature_mask,
warmup=warmup,
)
output = output * full_feature_mask
if i in self.aux_layers:
outputs.append(output)
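The mask construction above as a runnable standalone sketch (the S/N/E sizes are arbitrary): feature_mask is an (S, N, 1) per-frame keep-mask, and full_feature_mask broadcasts it over the dims past feature_unmasked_dim, so the first 256 dims are never zeroed:

import torch

S, N, E = 10, 4, 384                 # arbitrary toy sizes
feature_unmasked_dim = 256
feature_mask_dropout_prob = 0.15

src = torch.randn(S, N, E)
full_feature_mask = torch.ones_like(src)                  # (S, N, E)
# feature_mask is 0 with probability feature_mask_dropout_prob; shape (S, N, 1)
feature_mask = (torch.rand_like(src[..., :1]) > feature_mask_dropout_prob).to(src.dtype)
full_feature_mask[..., feature_unmasked_dim:] *= feature_mask

src = src * full_feature_mask        # zero the masked dims of dropped frames
# each layer then receives feature_mask (the (S, N, 1) keep-mask), and its
# LearnedScale modules compensate for the dropped features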

View File

@@ -326,6 +326,30 @@ def ScaledConv1d(*args,
return ans
class LearnedScale(torch.nn.Module):
"""
Module that learns a scale dependent on some kind of mask that is typically going to be 0 or 1
in training. The scale will be 1.0 if the mask is 1.0, but may be a different (learned) value
if the mask value is not 1.0.
The idea is that if we have some kind of feature mask that would always be 1.0 in
test mode but might sometimes be 0.0 in training mode, we might want to multiply
the remaining features by a value dependent on this mask.
"""
def __init__(self):
super(LearnedScale, self).__init__()
self.alpha = nn.Parameter(torch.tensor(0.0))
def forward(self,
x: Tensor,
mask: Union[Tensor, float]) -> Tensor:
"""
Mask should be either a number (probably 1.0) or a tensor that broadcasts with x.
"""
# Fast path: a float mask of 1.0 (e.g. in test mode) means no scaling.
if isinstance(mask, float) and mask == 1.0:
return x
return x * (1.0 + self.alpha * (1.0 - mask))
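A quick behavioral check of LearnedScale as defined above (a sketch: alpha is set by hand here, whereas in training it is learned; assumes the class is importable from scaling.py):

import torch
from scaling import LearnedScale

m = LearnedScale()
x = torch.randn(5, 3)
assert torch.equal(m(x, 1.0), x)             # float mask of 1.0: fast path, identity
with torch.no_grad():
    m.alpha.fill_(0.5)                       # pretend training moved alpha off 0.0
mask = torch.zeros(5, 1)                     # fully masked frames
assert torch.allclose(m(x, mask), 1.5 * x)   # scale = 1 + 0.5 * (1 - 0) = 1.5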
class ActivationBalancer(torch.nn.Module):
"""