Make dropout a schedule starting at 0.3.

Daniel Povey 2022-12-05 23:39:24 +08:00
parent 12fb2081b1
commit 22617da725
2 changed files with 22 additions and 9 deletions

View File

@@ -1212,7 +1212,15 @@ class TanSwish(torch.nn.Module):
         return TanSwishFunction.apply(x)

+# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates.
+class Dropout2(nn.Module):
+    def __init__(self, p: FloatLike):
+        super().__init__()
+        self.p = p
+
+    def forward(self, x: Tensor) -> Tensor:
+        return torch.nn.functional.dropout(x,
+                                           p=float(self.p),
+                                           training=self.training)

 class SwooshLFunction(torch.autograd.Function):
     """

View File

@@ -27,6 +27,7 @@ from encoder_interface import EncoderInterface
 from scaling import (
     ActivationBalancer,
     BasicNorm,
+    Dropout2,
     MaxEig,
     DoubleSwish,
     SwooshL,
@@ -107,11 +108,15 @@ class Zipformer(EncoderInterface):
         feedforward_dim: Union[int, Tuple[int]] = 1536,
         cnn_module_kernel: Union[int, Tuple[int]] = 31,
         pos_dim: int = 192,
-        dropout: float = 0.1,
+        dropout: FloatLike = None,  # see code below for default
         warmup_batches: float = 4000.0,
     ) -> None:
         super(Zipformer, self).__init__()

+        if dropout is None:
+            dropout = ScheduledFloat((0.0, 0.3),
+                                     (20000.0, 0.1))
+
         # this is not the probability of skipping a layer. It is the probability of
         # dropping out the "skip module" which allows the model to skip groups of
         # encoder stacks; when it's dropped out like this, it means we are forced
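With this default, and assuming ScheduledFloat interpolates linearly between its (batch_count, value) points the way the PiecewiseLinear sketch above does, the dropout rate starts at 0.3, decays linearly to 0.1 over the first 20000 batches, and stays at 0.1 afterwards:

# Hypothetical evaluation of the default schedule, using the PiecewiseLinear
# stand-in from the sketch above (ScheduledFloat itself is not shown in this diff).
dropout_schedule = PiecewiseLinear((0.0, 0.3), (20000.0, 0.1))

for bc in (0.0, 10000.0, 20000.0, 50000.0):
    dropout_schedule.batch_count = bc
    print(bc, float(dropout_schedule))
# 0.0     -> 0.3
# 10000.0 -> 0.3 + (0.1 - 0.3) * 10000 / 20000 = 0.2
# 20000.0 -> 0.1
# 50000.0 -> 0.1  (held after the last point)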
@@ -383,7 +388,7 @@ class ZipformerEncoderLayer(nn.Module):
         pos_head_dim: int,
         value_head_dim: int,
         feedforward_dim: int,
-        dropout: float = 0.1,
+        dropout: FloatLike = 0.1,
         cnn_module_kernel: int = 31,
         # layer_skip_rate will be overwritten to change warmup begin and end times.
         # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
@@ -948,7 +953,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):
     """
     def __init__(
         self, embed_dim: int,
-        dropout_rate: float,
+        dropout_rate: FloatLike,
         max_len: int = 1000,
         length_factor: float = 1.0,
     ) -> None:
@@ -956,7 +961,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         super(CompactRelPositionalEncoding, self).__init__()
         self.embed_dim = embed_dim
         assert embed_dim % 2 == 0
-        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.dropout = Dropout2(dropout_rate)
         self.pe = None
         assert length_factor >= 1.0
         self.length_factor = length_factor
@@ -1415,7 +1420,7 @@ class FeedforwardModule(nn.Module):
     def __init__(self,
                  embed_dim: int,
                  feedforward_dim: int,
-                 dropout: float):
+                 dropout: FloatLike):
         super(FeedforwardModule, self).__init__()
         self.in_proj = LinearWithAuxLoss(embed_dim, feedforward_dim,
                                          aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in())
@@ -1428,7 +1433,7 @@ class FeedforwardModule(nn.Module):
                                            max_abs=5.0,
                                            min_prob=0.25)
         self.activation = SwooshL()
-        self.dropout = nn.Dropout(dropout)
+        self.dropout = Dropout2(dropout)
         self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim,
                                           initial_scale=0.01,
                                           aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())
@@ -1684,7 +1689,7 @@ class Conv2dSubsampling(nn.Module):
         layer2_channels: int = 32,
         layer3_channels: int = 128,
         bottleneck_channels: int = 64,
-        dropout: float = 0.1,
+        dropout: FloatLike = 0.1,
     ) -> None:
         """
         Args:
@@ -1742,7 +1747,7 @@ class Conv2dSubsampling(nn.Module):
         self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels,
                                      aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())
-        self.dropout = nn.Dropout(dropout)
+        self.dropout = Dropout2(dropout)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
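For any of these scheduled rates to actually change during training, something in the training loop has to advance the batch count that the schedule reads; the mechanism icefall uses for that is not part of this diff. The fragment below is a hypothetical sketch using the PiecewiseLinear stand-in and the Dropout2 sketch from above.

# Hypothetical training-loop fragment: advance the schedule's batch count each
# step, so every Dropout2 holding this schedule applies the current rate.
dropout_schedule = PiecewiseLinear((0.0, 0.3), (20000.0, 0.1))
drop = Dropout2(dropout_schedule)
drop.train()

for batch_idx in range(3):
    x = torch.randn(4, 8)
    dropout_schedule.batch_count = float(batch_idx)
    y = drop(x)  # dropout probability follows the schedule as batch_idx grows

This is also why the nn.Dropout(dropout) call sites become Dropout2(dropout): nn.Dropout is written for a fixed float rate, while Dropout2 defers the float() conversion to forward(), so a scheduled rate is picked up as training progresses.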