Remove warmup

Daniel Povey 2022-10-06 12:33:43 +08:00
parent bb233d3449
commit 537c3537c0
3 changed files with 7 additions and 35 deletions

View File

@@ -94,7 +94,7 @@ class Conformer(EncoderInterface):
     def forward(
-        self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
+        self, x: torch.Tensor, x_lens: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -103,10 +103,6 @@ class Conformer(EncoderInterface):
           x_lens:
             A tensor of shape (batch_size,) containing the number of frames in
             `x` before padding.
-          warmup:
-            A floating point value that gradually increases from 0 throughout
-            training; when it is >= 1.0 we are "fully warmed up". It is used
-            to turn modules on sequentially.
         Returns:
           Return a tuple containing 2 tensors:
             - embeddings: its shape is (batch_size, output_seq_len, d_model)
@@ -125,7 +121,7 @@ class Conformer(EncoderInterface):
         mask = make_pad_mask(lengths)
         x = self.encoder(
-            x, pos_emb, src_key_padding_mask=mask, warmup=warmup
+            x, pos_emb, src_key_padding_mask=mask,
         )  # (T, N, C)
         x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
@@ -174,7 +170,7 @@ class ConformerEncoderLayer(nn.Module):
             DoubleSwish(),
             nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model,
-                         initial_scale=0.1),
+                         initial_scale=0.01),
         )
         self.feed_forward_macaron = nn.Sequential(
@@ -184,7 +180,7 @@ class ConformerEncoderLayer(nn.Module):
             DoubleSwish(),
             nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model,
-                         initial_scale=0.1),
+                         initial_scale=0.01),
         )
         self.conv_module = ConvolutionModule(d_model,
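
Note on the initial_scale changes in this file (0.1 -> 0.01 above, and the analogous 0.5 -> 0.05 changes further down): with the warmup-based damping removed, the output projections of the feed-forward, attention and convolution modules are initialized ten times smaller, so each layer's contribution to the residual stream still starts near zero. The snippet below is only a minimal stand-in written to illustrate what initial_scale does here; it is not the actual icefall ScaledLinear implementation.

import torch
import torch.nn as nn

class ScaledLinearSketch(nn.Linear):
    # Illustrative stand-in for ScaledLinear: shrink the initial parameters
    # by `initial_scale` so the module's output starts close to zero.
    def __init__(self, in_features, out_features, initial_scale=1.0, **kwargs):
        super().__init__(in_features, out_features, **kwargs)
        with torch.no_grad():
            self.weight.mul_(initial_scale)
            if self.bias is not None:
                self.bias.mul_(initial_scale)

# With initial_scale=0.01 instead of 0.1, the feed-forward output projection
# contributes roughly 10x less to the residual at the start of training.
proj = ScaledLinearSketch(2048, 512, initial_scale=0.01)
print(proj(torch.randn(4, 2048)).abs().mean())
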
@@ -207,7 +203,6 @@ class ConformerEncoderLayer(nn.Module):
         attn_scores_in: Optional[Tensor] = None,
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-        warmup: float = 1.0,
         layerdrop_scale: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
         """
@@ -220,8 +215,6 @@ class ConformerEncoderLayer(nn.Module):
               passed from layer to layer.
             src_mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
-            warmup: controls selective bypass of of layers; if < 1.0, we will
-              bypass layers more frequently.
             batch_split: if not None, this layer will only be applied to
             layerdrop_scale: an optional Tensor of broadcasting with `src` that will be used as a scale
               on the change in the embeddings made by this layer.
@@ -235,11 +228,6 @@ class ConformerEncoderLayer(nn.Module):
         """
         src_orig = src
-        warmup_scale = min(0.1 + warmup, 1.0)
-        # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
-        # completely bypass it.
-        alpha = warmup_scale if self.training else 1.0
         # macaron style feed forward module
         src = src + self.feed_forward_macaron(src)
@@ -262,13 +250,6 @@ class ConformerEncoderLayer(nn.Module):
         src = self.norm_final(self.balancer(src))
-        if alpha != 1.0 or layerdrop_scale is not None:
-            # the if(self.training) part is to ensure we have a derivative for
-            # self.scale_alpha.
-            src_offset = src - src_orig
-            scale = alpha * layerdrop_scale
-            src = src_orig + src_offset * scale
         return src, attn_scores_out
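
For reference, the two deleted blocks above implemented the behaviour this commit removes: during training, alpha ramps from 0.1 toward 1.0 as warmup grows, and the layer's change to its input is scaled by alpha (times layerdrop_scale if given), so layers are largely bypassed early in training. A simplified sketch of that old logic, with the attention/feed-forward/conv body abstracted into layer_fn and the layerdrop_scale=None case handled explicitly so the snippet is self-contained:

from typing import Callable, Optional
from torch import Tensor

def old_layer_forward(layer_fn: Callable[[Tensor], Tensor],
                      src: Tensor,
                      warmup: float,
                      training: bool,
                      layerdrop_scale: Optional[Tensor] = None) -> Tensor:
    # Sketch of the removed bypass logic; layer_fn stands in for the body of
    # ConformerEncoderLayer.forward (macaron FF, self-attention, conv, FF).
    src_orig = src
    warmup_scale = min(0.1 + warmup, 1.0)
    # alpha = 1.0 means fully use this encoder layer, 0.0 would bypass it.
    alpha = warmup_scale if training else 1.0
    src = layer_fn(src)
    if alpha != 1.0 or layerdrop_scale is not None:
        scale = alpha if layerdrop_scale is None else alpha * layerdrop_scale
        src = src_orig + (src - src_orig) * scale
    return src

# After this commit the alpha/warmup scaling above is gone entirely.
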
@@ -383,7 +364,6 @@ class ConformerEncoder(nn.Module):
         pos_emb: Tensor,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-        warmup: float = 1.0,
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
@@ -437,7 +417,6 @@ class ConformerEncoder(nn.Module):
                 attn_scores,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
-                warmup=warmup,
                 layerdrop_scale=layerdrop_scales[i],
             )
             output = output * feature_mask
@@ -564,7 +543,7 @@ class RelPositionMultiheadAttention(nn.Module):
                            channel_dim=-1, max_abs=10.0,
                            min_positive=0.0, max_positive=1.0)
         self.out_proj = ScaledLinear(
-            embed_dim // 2, embed_dim, bias=True, initial_scale=0.5
+            embed_dim // 2, embed_dim, bias=True, initial_scale=0.05
         )
         self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
@@ -982,7 +961,7 @@ class ConvolutionModule(nn.Module):
             stride=1,
             padding=0,
             bias=bias,
-            initial_scale=0.5,
+            initial_scale=0.05,
         )
     def forward(self,
@@ -1257,14 +1236,12 @@ def _test_conformer_main():
     f = c(
         torch.randn(batch_size, seq_len, feature_dim),
         torch.full((batch_size,), seq_len, dtype=torch.int64),
-        warmup=0.5,
     )
     f  # to remove flake8 warnings
     c.eval()
     f = c(
         torch.randn(batch_size, seq_len, feature_dim),
         torch.full((batch_size,), seq_len, dtype=torch.int64),
-        warmup=0.5,
     )
     f  # to remove flake8 warnings
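
After this change, Conformer.forward takes only the features and their lengths, as the updated _test_conformer_main above shows. A minimal usage sketch, assuming the modified conformer.py is importable; the constructor arguments and shapes are illustrative values, not part of this diff:

import torch
from conformer import Conformer  # assumption: the file modified above

feature_dim, seq_len, batch_size = 80, 100, 5
c = Conformer(num_features=feature_dim)  # other hyper-parameters left at defaults

x = torch.randn(batch_size, seq_len, feature_dim)
x_lens = torch.full((batch_size,), seq_len, dtype=torch.int64)

embeddings, out_lens = c(x, x_lens)  # the warmup= keyword is no longer accepted
print(embeddings.shape, out_lens)
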

View File

@@ -75,7 +75,6 @@ class Transducer(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
-        warmup: float = 1.0,
     ) -> torch.Tensor:
         """
         Args:
@@ -96,9 +95,6 @@ class Transducer(nn.Module):
           lm_scale:
             The scale to smooth the loss with lm (output of predictor network)
             part
-          warmup:
-            A value warmup >= 0 that determines which modules are active, values
-            warmup > 1 "are fully warmed up" and all modules will be active.
         Returns:
           Return the transducer loss.
@@ -114,7 +110,7 @@ class Transducer(nn.Module):
         assert x.size(0) == x_lens.size(0) == y.dim0
-        encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
+        encoder_out, x_lens = self.encoder(x, x_lens)
         assert torch.all(x_lens > 0)
         # Now for the decoder, i.e., the prediction network

View File

@@ -619,7 +619,6 @@ def compute_loss(
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
-            warmup=warmup,
         )
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
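
The trailing comment refers to a schedule that train.py keeps: the pruned loss is weighted by how far training has progressed, which is separate from the per-module warmup removed in this commit. The sketch below only illustrates that kind of schedule; the thresholds and values are assumptions, not the actual train.py code.

def pruned_loss_scale(batch_idx: int, model_warm_step: int = 3000) -> float:
    # Illustrative schedule (assumed values): ignore the pruned loss while the
    # model warms up, keep it small for the same amount of time afterwards,
    # then use it at full weight.
    if batch_idx < model_warm_step:
        return 0.0
    elif batch_idx < 2 * model_warm_step:
        return 0.1
    else:
        return 1.0

# Example: loss = simple_loss + pruned_loss_scale(batch_idx) * pruned_loss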