diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index c177b1d78..308e2e13d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -94,7 +94,7 @@ class Conformer(EncoderInterface): def forward( - self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 + self, x: torch.Tensor, x_lens: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -103,10 +103,6 @@ class Conformer(EncoderInterface): x_lens: A tensor of shape (batch_size,) containing the number of frames in `x` before padding. - warmup: - A floating point value that gradually increases from 0 throughout - training; when it is >= 1.0 we are "fully warmed up". It is used - to turn modules on sequentially. Returns: Return a tuple containing 2 tensors: - embeddings: its shape is (batch_size, output_seq_len, d_model) @@ -125,7 +121,7 @@ class Conformer(EncoderInterface): mask = make_pad_mask(lengths) x = self.encoder( - x, pos_emb, src_key_padding_mask=mask, warmup=warmup + x, pos_emb, src_key_padding_mask=mask, ) # (T, N, C) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) @@ -174,7 +170,7 @@ class ConformerEncoderLayer(nn.Module): DoubleSwish(), nn.Dropout(dropout), ScaledLinear(dim_feedforward, d_model, - initial_scale=0.1), + initial_scale=0.01), ) self.feed_forward_macaron = nn.Sequential( @@ -184,7 +180,7 @@ class ConformerEncoderLayer(nn.Module): DoubleSwish(), nn.Dropout(dropout), ScaledLinear(dim_feedforward, d_model, - initial_scale=0.1), + initial_scale=0.01), ) self.conv_module = ConvolutionModule(d_model, @@ -207,7 +203,6 @@ class ConformerEncoderLayer(nn.Module): attn_scores_in: Optional[Tensor] = None, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - warmup: float = 1.0, layerdrop_scale: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: """ @@ -220,8 +215,6 @@ class ConformerEncoderLayer(nn.Module): passed from layer to layer. src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). - warmup: controls selective bypass of of layers; if < 1.0, we will - bypass layers more frequently. batch_split: if not None, this layer will only be applied to layerdrop_scale: an optional Tensor of broadcasting with `src` that will be used as a scale on the change in the embeddings made by this layer. @@ -235,11 +228,6 @@ class ConformerEncoderLayer(nn.Module): """ src_orig = src - warmup_scale = min(0.1 + warmup, 1.0) - # alpha = 1.0 means fully use this encoder layer, 0.0 would mean - # completely bypass it. - alpha = warmup_scale if self.training else 1.0 - # macaron style feed forward module src = src + self.feed_forward_macaron(src) @@ -262,13 +250,6 @@ class ConformerEncoderLayer(nn.Module): src = self.norm_final(self.balancer(src)) - if alpha != 1.0 or layerdrop_scale is not None: - # the if(self.training) part is to ensure we have a derivative for - # self.scale_alpha. - src_offset = src - src_orig - scale = alpha * layerdrop_scale - src = src_orig + src_offset * scale - return src, attn_scores_out @@ -383,7 +364,6 @@ class ConformerEncoder(nn.Module): pos_emb: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - warmup: float = 1.0, ) -> Tensor: r"""Pass the input through the encoder layers in turn. @@ -437,7 +417,6 @@ class ConformerEncoder(nn.Module): attn_scores, src_mask=mask, src_key_padding_mask=src_key_padding_mask, - warmup=warmup, layerdrop_scale=layerdrop_scales[i], ) output = output * feature_mask @@ -564,7 +543,7 @@ class RelPositionMultiheadAttention(nn.Module): channel_dim=-1, max_abs=10.0, min_positive=0.0, max_positive=1.0) self.out_proj = ScaledLinear( - embed_dim // 2, embed_dim, bias=True, initial_scale=0.5 + embed_dim // 2, embed_dim, bias=True, initial_scale=0.05 ) self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads)) @@ -982,7 +961,7 @@ class ConvolutionModule(nn.Module): stride=1, padding=0, bias=bias, - initial_scale=0.5, + initial_scale=0.05, ) def forward(self, @@ -1257,14 +1236,12 @@ def _test_conformer_main(): f = c( torch.randn(batch_size, seq_len, feature_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), - warmup=0.5, ) f # to remove flake8 warnings c.eval() f = c( torch.randn(batch_size, seq_len, feature_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), - warmup=0.5, ) f # to remove flake8 warnings diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index 24898ed09..ee88a9159 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -75,7 +75,6 @@ class Transducer(nn.Module): prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, - warmup: float = 1.0, ) -> torch.Tensor: """ Args: @@ -96,9 +95,6 @@ class Transducer(nn.Module): lm_scale: The scale to smooth the loss with lm (output of predictor network) part - warmup: - A value warmup >= 0 that determines which modules are active, values - warmup > 1 "are fully warmed up" and all modules will be active. Returns: Return the transducer loss. @@ -114,7 +110,7 @@ class Transducer(nn.Module): assert x.size(0) == x_lens.size(0) == y.dim0 - encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup) + encoder_out, x_lens = self.encoder(x, x_lens) assert torch.all(x_lens > 0) # Now for the decoder, i.e., the prediction network diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 176d8f207..d3680e75e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -619,7 +619,6 @@ def compute_loss( prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, - warmup=warmup, ) # after the main warmup step, we keep pruned_loss_scale small # for the same amount of time (model_warm_step), to avoid