diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 9de413fb4..4cb244769 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -18,17 +18,18 @@
 import copy
 import math
 import warnings
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 import logging
 import torch
 import random
 from encoder_interface import EncoderInterface
 from scaling import (
     ActivationBalancer,
     BasicNorm,
     DoubleSwish,
     ScaledConv1d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+    LearnedScale,
 )
 from torch import Tensor, nn
 
@@ -171,6 +172,7 @@ class ConformerEncoderLayer(nn.Module):
         self.self_attn = RelPositionMultiheadAttention(
             d_model, nhead, dropout=dropout,
         )
+        self.self_attn_scale = LearnedScale()
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -181,6 +183,7 @@ class ConformerEncoderLayer(nn.Module):
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.1),
         )
+        self.feed_forward_scale = LearnedScale()
 
         self.feed_forward_macaron = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -191,11 +194,14 @@ class ConformerEncoderLayer(nn.Module):
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.1),
         )
+        self.feed_forward_macaron_scale = LearnedScale()
 
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+        self.conv_scale = LearnedScale()
 
         self.norm_final = BasicNorm(d_model)
+        self.final_scale = LearnedScale()
 
         # try to ensure the output is close to zero-mean (or at least, zero-median).
self.balancer = ActivationBalancer( @@ -209,11 +215,11 @@ class ConformerEncoderLayer(nn.Module): def forward( self, src: Tensor, + feature_mask: Union[Tensor, float], pos_emb: Tensor, attn_scores_in: Optional[Tensor] = None, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - feature_mask: Optional[Tensor] = None, warmup: float = 1.0, ) -> Tensor: """ @@ -233,10 +239,10 @@ class ConformerEncoderLayer(nn.Module): Shape: src: (S, N, E). + feature_mask: float, or (S, N, 1) pos_emb: (N, 2*S-1, E) src_mask: (S, S). src_key_padding_mask: (N, S). - feature_mask: (S, N, E) S is the source sequence length, N is the batch size, E is the feature number """ src_orig = src @@ -254,7 +260,8 @@ class ConformerEncoderLayer(nn.Module): alpha = 1.0 # macaron style feed forward module - src = src + self.feed_forward_macaron(src) + src = src + self.feed_forward_macaron_scale(self.feed_forward_macaron(src), + feature_mask) # multi-headed self-attention module src_att, _, attn_scores_out = self.self_attn( @@ -264,25 +271,24 @@ class ConformerEncoderLayer(nn.Module): attn_mask=src_mask, key_padding_mask=src_key_padding_mask, ) - src = src + src_att + src = src + self.self_attn_scale(src_att, feature_mask) # convolution module - src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask) - - + src = src + self.conv_scale(self.conv_module(src, src_key_padding_mask=src_key_padding_mask), + feature_mask) # feed forward module - src = src + self.feed_forward(src) + src = src + self.feed_forward_scale(self.feed_forward(src), + feature_mask) + + src = self.final_scale(src, feature_mask) src = self.norm_final(self.balancer(src)) if alpha != 1.0: src = alpha * src + (1 - alpha) * src_orig - if feature_mask is not None: - src = src * feature_mask - return src, attn_scores_out @@ -359,23 +365,28 @@ class ConformerEncoder(nn.Module): feature_mask_dropout_prob = 0.15 feature_unmasked_dim = 256 # hardcode dim for now, 1st 256 are non-masked. 
-            feature_mask = torch.ones_like(src)  # S, N, E
-            # is_masked_frame is 0 with probability `feature_mask_dropout_prob`
-            is_masked_frame = (torch.rand_like(src[...,:1]) > feature_mask_dropout_prob).to(src.dtype)
-            feature_mask[..., feature_unmasked_dim:] *= is_masked_frame
+            full_feature_mask = torch.ones_like(src)  # S, N, E
+            # feature_mask is 0 with probability `feature_mask_dropout_prob`
+            # feature_mask shape: (S, N, 1)
+            feature_mask = (torch.rand_like(src[...,:1]) > feature_mask_dropout_prob).to(src.dtype)
+            full_feature_mask[..., feature_unmasked_dim:] *= feature_mask
         else:
-            feature_mask = None
+            feature_mask = 1.0
+            full_feature_mask = 1.0
+
+        src = src * full_feature_mask
 
         for i, mod in enumerate(self.layers):
             output, attn_scores = mod(
                 output,
+                feature_mask,
                 pos_emb,
                 attn_scores,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
-                feature_mask=feature_mask,
                 warmup=warmup,
             )
+            output = output * full_feature_mask
 
             if i in self.aux_layers:
                 outputs.append(output)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 05bf4fb65..8c3aa6a9d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -326,6 +326,30 @@ def ScaledConv1d(*args,
     return ans
 
+
+class LearnedScale(torch.nn.Module):
+    """
+    Module that learns a scale dependent on some kind of mask that is typically going to be 0 or 1
+    in training.  The scale will be 1.0 if the mask is 1.0, but may be a different (learned) value
+    if the mask value is not 1.0.
+
+    The idea is that if we have some kind of feature mask that would always be 1.0 in
+    test mode but might sometimes be 0.0 in training mode, we might want to multiply
+    the remaining features by a value dependent on this mask.
+    """
+    def __init__(self):
+        super(LearnedScale, self).__init__()
+        # alpha == 0.0 makes the module an exact identity at init, regardless of mask.
+        self.alpha = nn.Parameter(torch.tensor(0.0))
+
+    def forward(self,
+                x: Tensor,
+                mask: Tensor) -> Tensor:
+        """
+        Mask should either be a number (probably 1.0) or a tensor that broadcasts with x.
+        """
+        # NOTE: must not use `mask is 1.0` -- float identity is not guaranteed across
+        # modules, so compare by value (and only for scalar floats, since `==` on a
+        # Tensor would produce an elementwise result, not a bool).
+        if self.training and isinstance(mask, float) and mask == 1.0:
+            return x
+        return x * (1.0 + self.alpha * (1.0 - mask))
+
 
 class ActivationBalancer(torch.nn.Module):
     """