class Dropout2(nn.Module):
    """Dropout layer whose rate may follow a schedule.

    Behaves like ordinary dropout, except that the rate ``p`` is kept as
    given and converted with ``float()`` on every forward call rather than
    once at construction time.
    """

    def __init__(self, p: FloatLike):
        """Store the dropout rate.

        Args:
          p: the dropout rate; anything that ``float()`` accepts.
             NOTE(review): presumably a plain float or a schedule object
             such as ScheduledFloat -- confirm against the definition of
             FloatLike.
        """
        super().__init__()
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        """Apply dropout to ``x`` using the current value of ``self.p``.

        Dropout is only active when the module is in training mode.
        """
        # Re-read the rate on each call so that a schedule-valued `p`
        # takes effect.
        rate = float(self.p)
        return torch.nn.functional.dropout(x, p=rate, training=self.training)
It is the probability of # dropping out the "skip module" which allows the model to skip groups of # encoder stacks; when it's dropped out like this, it means we are forced @@ -383,7 +388,7 @@ class ZipformerEncoderLayer(nn.Module): pos_head_dim: int, value_head_dim: int, feedforward_dim: int, - dropout: float = 0.1, + dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, # layer_skip_rate will be overwritten to change warmup begin and end times. # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom() @@ -948,7 +953,7 @@ class CompactRelPositionalEncoding(torch.nn.Module): """ def __init__( self, embed_dim: int, - dropout_rate: float, + dropout_rate: FloatLike, max_len: int = 1000, length_factor: float = 1.0, ) -> None: @@ -956,7 +961,7 @@ class CompactRelPositionalEncoding(torch.nn.Module): super(CompactRelPositionalEncoding, self).__init__() self.embed_dim = embed_dim assert embed_dim % 2 == 0 - self.dropout = torch.nn.Dropout(dropout_rate) + self.dropout = Dropout2(dropout_rate) self.pe = None assert length_factor >= 1.0 self.length_factor = length_factor @@ -1415,7 +1420,7 @@ class FeedforwardModule(nn.Module): def __init__(self, embed_dim: int, feedforward_dim: int, - dropout: float): + dropout: FloatLike): super(FeedforwardModule, self).__init__() self.in_proj = LinearWithAuxLoss(embed_dim, feedforward_dim, aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in()) @@ -1428,7 +1433,7 @@ class FeedforwardModule(nn.Module): max_abs=5.0, min_prob=0.25) self.activation = SwooshL() - self.dropout = nn.Dropout(dropout) + self.dropout = Dropout2(dropout) self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim, initial_scale=0.01, aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out()) @@ -1684,7 +1689,7 @@ class Conv2dSubsampling(nn.Module): layer2_channels: int = 32, layer3_channels: int = 128, bottleneck_channels: int = 64, - dropout: float = 0.1, + dropout: FloatLike = 0.1, ) -> None: """ Args: @@ -1742,7 +1747,7 @@ 
class Conv2dSubsampling(nn.Module): self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels, aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out()) - self.dropout = nn.Dropout(dropout) + self.dropout = Dropout2(dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: