Make dropout a schedule starting at 0.3.
Commit 22617da725 (parent 12fb2081b1)
@@ -1212,7 +1212,15 @@ class TanSwish(torch.nn.Module):
         return TanSwishFunction.apply(x)
 
 
+# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates.
+class Dropout2(nn.Module):
+    def __init__(self, p: FloatLike):
+        super().__init__()
+        self.p = p
+
+    def forward(self, x: Tensor) -> Tensor:
+        return torch.nn.functional.dropout(x,
+                                           p=float(self.p),
+                                           training=self.training)
+
 class SwooshLFunction(torch.autograd.Function):
     """
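Note on the new class: FloatLike means p can be a plain float or a schedule object whose current value is obtained with float(), and because forward() calls float(self.p) on every invocation, the effective dropout rate can change during training without rebuilding the module. A minimal sketch of that behaviour, using a hypothetical ConstantSchedule stand-in (not icefall's ScheduledFloat) and a copy of the class above:

    import torch
    from torch import nn, Tensor

    class ConstantSchedule:
        # Hypothetical stand-in for a schedule: anything exposing __float__ works as a FloatLike.
        def __init__(self, value: float):
            self.value = value

        def __float__(self) -> float:
            return self.value

    class Dropout2Sketch(nn.Module):
        # Same idea as Dropout2 above: re-read the rate on every forward pass.
        def __init__(self, p):
            super().__init__()
            self.p = p

        def forward(self, x: Tensor) -> Tensor:
            return torch.nn.functional.dropout(x, p=float(self.p), training=self.training)

    drop = Dropout2Sketch(ConstantSchedule(0.3))
    x = torch.randn(4, 8)
    y = drop(x)           # drops with p=0.3
    drop.p.value = 0.1    # "advance" the schedule
    y = drop(x)           # now drops with p=0.1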
@@ -27,6 +27,7 @@ from encoder_interface import EncoderInterface
 from scaling import (
     ActivationBalancer,
     BasicNorm,
+    Dropout2,
     MaxEig,
     DoubleSwish,
     SwooshL,
@@ -107,11 +108,15 @@ class Zipformer(EncoderInterface):
             feedforward_dim: Union[int, Tuple[int]] = 1536,
             cnn_module_kernel: Union[int, Tuple[int]] = 31,
             pos_dim: int = 192,
-            dropout: float = 0.1,
+            dropout: FloatLike = None,  # see code below for default
             warmup_batches: float = 4000.0,
     ) -> None:
         super(Zipformer, self).__init__()
 
+        if dropout is None:
+            dropout = ScheduledFloat((0.0, 0.3),
+                                     (20000.0, 0.1))
+
         # this is not the probability of skipping a layer. It is the probability of
         # dropping out the "skip module" which allows the model to skip groups of
         # encoder stacks; when it's dropped out like this, it means we are forced
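The default schedule here is what the commit title refers to: dropout starts at 0.3 and decays to 0.1 by batch 20000. A small sketch of the intended value curve, assuming ScheduledFloat interpolates piecewise-linearly between its (batch_count, value) points and clamps outside them (the real implementation lives in scaling.py):

    def scheduled_value(batch: float,
                        points=((0.0, 0.3), (20000.0, 0.1))) -> float:
        # Sketch of a two-point schedule: clamp outside the range, interpolate inside.
        (x0, y0), (x1, y1) = points
        if batch <= x0:
            return y0
        if batch >= x1:
            return y1
        return y0 + (y1 - y0) * (batch - x0) / (x1 - x0)

    for b in (0, 5000, 10000, 20000, 30000):
        print(b, round(scheduled_value(b), 3))   # 0.3, 0.25, 0.2, 0.1, 0.1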
@@ -383,7 +388,7 @@ class ZipformerEncoderLayer(nn.Module):
             pos_head_dim: int,
             value_head_dim: int,
             feedforward_dim: int,
-            dropout: float = 0.1,
+            dropout: FloatLike = 0.1,
             cnn_module_kernel: int = 31,
             # layer_skip_rate will be overwritten to change warmup begin and end times.
             # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
@@ -948,7 +953,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):
     """
     def __init__(
         self, embed_dim: int,
-        dropout_rate: float,
+        dropout_rate: FloatLike,
         max_len: int = 1000,
         length_factor: float = 1.0,
     ) -> None:
@@ -956,7 +961,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         super(CompactRelPositionalEncoding, self).__init__()
         self.embed_dim = embed_dim
         assert embed_dim % 2 == 0
-        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.dropout = Dropout2(dropout_rate)
         self.pe = None
         assert length_factor >= 1.0
         self.length_factor = length_factor
@@ -1415,7 +1420,7 @@ class FeedforwardModule(nn.Module):
     def __init__(self,
                  embed_dim: int,
                  feedforward_dim: int,
-                 dropout: float):
+                 dropout: FloatLike):
         super(FeedforwardModule, self).__init__()
         self.in_proj = LinearWithAuxLoss(embed_dim, feedforward_dim,
                                          aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in())
@@ -1428,7 +1433,7 @@ class FeedforwardModule(nn.Module):
                                        max_abs=5.0,
                                        min_prob=0.25)
         self.activation = SwooshL()
-        self.dropout = nn.Dropout(dropout)
+        self.dropout = Dropout2(dropout)
         self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim,
                                           initial_scale=0.01,
                                           aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())
@@ -1684,7 +1689,7 @@ class Conv2dSubsampling(nn.Module):
         layer2_channels: int = 32,
         layer3_channels: int = 128,
         bottleneck_channels: int = 64,
-        dropout: float = 0.1,
+        dropout: FloatLike = 0.1,
     ) -> None:
         """
         Args:
@@ -1742,7 +1747,7 @@ class Conv2dSubsampling(nn.Module):
         self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels,
                                      aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())
 
-        self.dropout = nn.Dropout(dropout)
+        self.dropout = Dropout2(dropout)
 
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
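The remaining hunks all make the same substitution: nn.Dropout(dropout) (or torch.nn.Dropout(dropout_rate)) becomes Dropout2(...). The swap matters because nn.Dropout expects a plain float and applies whatever value it was constructed with, whereas Dropout2 keeps the FloatLike and re-evaluates float(self.p) on each forward call, so the Zipformer-wide schedule actually reaches these modules. A contrast sketch, reusing the hypothetical ConstantSchedule and Dropout2Sketch defined after the first hunk:

    sched = ConstantSchedule(0.3)

    fixed = nn.Dropout(float(sched))   # rate frozen at 0.3 at construction time
    scheduled = Dropout2Sketch(sched)  # rate re-read from the schedule on every call

    sched.value = 0.1
    x = torch.randn(2, 4)
    _ = fixed(x)       # still drops with p=0.3
    _ = scheduled(x)   # now drops with p=0.1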