Merge dropout schedule, 0.3 ... 0.1 over 20k batches

Daniel Povey 2022-12-08 18:18:46 +08:00
commit 3f82ee0783
3 changed files with 24 additions and 10 deletions

scaling.py

@@ -1258,7 +1258,15 @@ class TanSwish(torch.nn.Module):
         return TanSwishFunction.apply(x)

+# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates.
+class Dropout2(nn.Module):
+    def __init__(self, p: FloatLike):
+        super().__init__()
+        self.p = p
+
+    def forward(self, x: Tensor) -> Tensor:
+        return torch.nn.functional.dropout(x,
+                                           p=float(self.p),
+                                           training=self.training)
+
 class SwooshLFunction(torch.autograd.Function):
     """

train.py

@@ -60,6 +60,7 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from zipformer import Zipformer
+from scaling import ScheduledFloat
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -498,7 +499,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         attention_share_layers=to_int_tuple(params.attention_share_layers),
         feedforward_dim=to_int_tuple(params.feedforward_dim),
         cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
-        dropout=0.1,
+        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
         warmup_batches=4000.0,
     )
     return encoder
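The ScheduledFloat here is constructed from (batch_count, value) breakpoints: 0.3 at batch 0 decaying to 0.1 at batch 20000, matching the commit title. Assuming it interpolates linearly between breakpoints and clamps outside them (this interpolation rule is an assumption, not confirmed by the diff), the schedule behaves like this hypothetical helper:

def schedule_value(points, batch_count):
    # points: sorted (batch_count, value) pairs, e.g. [(0.0, 0.3), (20000.0, 0.1)];
    # clamp before the first and after the last breakpoint, interpolate between.
    if batch_count <= points[0][0]:
        return points[0][1]
    if batch_count >= points[-1][0]:
        return points[-1][1]
    for (b0, v0), (b1, v1) in zip(points, points[1:]):
        if b0 <= batch_count <= b1:
            t = (batch_count - b0) / (b1 - b0)
            return v0 + t * (v1 - v0)

points = [(0.0, 0.3), (20000.0, 0.1)]
assert schedule_value(points, 0) == 0.3                  # start of training
assert abs(schedule_value(points, 10000) - 0.2) < 1e-9   # halfway point
assert schedule_value(points, 50000) == 0.1              # clamped after 20k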

zipformer.py

@@ -27,6 +27,7 @@ from encoder_interface import EncoderInterface
 from scaling import (
     ActivationBalancer,
     BasicNorm,
+    Dropout2,
     MaxEig,
     DoubleSwish,
     SwooshL,
@@ -109,11 +110,15 @@ class Zipformer(EncoderInterface):
         feedforward_dim: Union[int, Tuple[int]] = 1536,
         cnn_module_kernel: Union[int, Tuple[int]] = 31,
         pos_dim: int = 192,
-        dropout: float = 0.1,
+        dropout: FloatLike = None,  # see code below for default
         warmup_batches: float = 4000.0,
     ) -> None:
         super(Zipformer, self).__init__()
+
+        if dropout is None:
+            dropout = ScheduledFloat((0.0, 0.3),
+                                     (20000.0, 0.1))
         # this is not the probability of skipping a layer. It is the probability of
         # dropping out the "skip module" which allows the model to skip groups of
         # encoder stacks; when it's dropped out like this, it means we are forced
@@ -385,7 +390,7 @@ class ZipformerEncoderLayer(nn.Module):
         pos_head_dim: int,
         value_head_dim: int,
         feedforward_dim: int,
-        dropout: float = 0.1,
+        dropout: FloatLike = 0.1,
         cnn_module_kernel: int = 31,
         # layer_skip_rate will be overwritten to change warmup begin and end times.
         # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
@@ -950,7 +955,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):
     """
     def __init__(
         self, embed_dim: int,
-        dropout_rate: float,
+        dropout_rate: FloatLike,
         max_len: int = 1000,
         length_factor: float = 1.0,
     ) -> None:
@@ -958,7 +963,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         super(CompactRelPositionalEncoding, self).__init__()
         self.embed_dim = embed_dim
         assert embed_dim % 2 == 0
-        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.dropout = Dropout2(dropout_rate)
         self.pe = None
         assert length_factor >= 1.0
         self.length_factor = length_factor
@@ -1417,7 +1422,7 @@ class FeedforwardModule(nn.Module):
     def __init__(self,
                  embed_dim: int,
                  feedforward_dim: int,
-                 dropout: float):
+                 dropout: FloatLike):
         super(FeedforwardModule, self).__init__()
         self.in_proj = LinearWithAuxLoss(embed_dim, feedforward_dim,
                                          aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in())
@@ -1430,7 +1435,7 @@ class FeedforwardModule(nn.Module):
                                           max_abs=5.0,
                                           min_prob=0.25)
         self.activation = SwooshL()
-        self.dropout = nn.Dropout(dropout)
+        self.dropout = Dropout2(dropout)
         self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim,
                                           initial_scale=0.01,
                                           aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())
@@ -1708,7 +1713,7 @@ class Conv2dSubsampling(nn.Module):
         layer2_channels: int = 32,
         layer3_channels: int = 128,
         bottleneck_channels: int = 64,
-        dropout: float = 0.1,
+        dropout: FloatLike = 0.1,
     ) -> None:
         """
         Args:
@@ -1773,7 +1778,7 @@ class Conv2dSubsampling(nn.Module):
         self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels,
                                      aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())
-        self.dropout = nn.Dropout(dropout)
+        self.dropout = Dropout2(dropout)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
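The zipformer.py hunks above all make the same substitution: dropout parameters are retyped from float to FloatLike, and nn.Dropout is replaced by Dropout2. The swap matters because nn.Dropout fixes p at construction time, while Dropout2 re-reads float(self.p) on every call. A short sketch, with Rate as a hypothetical mutable stand-in for ScheduledFloat:

import torch.nn as nn
import torch.nn.functional as F

class Dropout2(nn.Module):  # as added in scaling.py above
    def __init__(self, p):
        super().__init__()
        self.p = p

    def forward(self, x):
        return F.dropout(x, p=float(self.p), training=self.training)

class Rate:  # hypothetical mutable rate object
    def __init__(self, p):
        self.p = p

    def __float__(self):
        return self.p

rate = Rate(0.3)
frozen = nn.Dropout(float(rate))  # p baked in at 0.3 for the whole run
live = Dropout2(rate)             # p re-read on each forward pass
rate.p = 0.1
# frozen.p is still 0.3, but live now drops at p == 0.1 in training mode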