mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Remove warmup
This commit is contained in:
parent
bb233d3449
commit
537c3537c0
@ -94,7 +94,7 @@ class Conformer(EncoderInterface):
|
||||
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
@ -103,10 +103,6 @@ class Conformer(EncoderInterface):
|
||||
x_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames in
|
||||
`x` before padding.
|
||||
warmup:
|
||||
A floating point value that gradually increases from 0 throughout
|
||||
training; when it is >= 1.0 we are "fully warmed up". It is used
|
||||
to turn modules on sequentially.
|
||||
Returns:
|
||||
Return a tuple containing 2 tensors:
|
||||
- embeddings: its shape is (batch_size, output_seq_len, d_model)
|
||||
@ -125,7 +121,7 @@ class Conformer(EncoderInterface):
|
||||
mask = make_pad_mask(lengths)
|
||||
|
||||
x = self.encoder(
|
||||
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
|
||||
x, pos_emb, src_key_padding_mask=mask,
|
||||
) # (T, N, C)
|
||||
|
||||
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
@ -174,7 +170,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
DoubleSwish(),
|
||||
nn.Dropout(dropout),
|
||||
ScaledLinear(dim_feedforward, d_model,
|
||||
initial_scale=0.1),
|
||||
initial_scale=0.01),
|
||||
)
|
||||
|
||||
self.feed_forward_macaron = nn.Sequential(
|
||||
@ -184,7 +180,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
DoubleSwish(),
|
||||
nn.Dropout(dropout),
|
||||
ScaledLinear(dim_feedforward, d_model,
|
||||
initial_scale=0.1),
|
||||
initial_scale=0.01),
|
||||
)
|
||||
|
||||
self.conv_module = ConvolutionModule(d_model,
|
||||
@ -207,7 +203,6 @@ class ConformerEncoderLayer(nn.Module):
|
||||
attn_scores_in: Optional[Tensor] = None,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
warmup: float = 1.0,
|
||||
layerdrop_scale: Optional[Tensor] = None,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
@ -220,8 +215,6 @@ class ConformerEncoderLayer(nn.Module):
|
||||
passed from layer to layer.
|
||||
src_mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
warmup: controls selective bypass of of layers; if < 1.0, we will
|
||||
bypass layers more frequently.
|
||||
batch_split: if not None, this layer will only be applied to
|
||||
layerdrop_scale: an optional Tensor of broadcasting with `src` that will be used as a scale
|
||||
on the change in the embeddings made by this layer.
|
||||
@ -235,11 +228,6 @@ class ConformerEncoderLayer(nn.Module):
|
||||
"""
|
||||
src_orig = src
|
||||
|
||||
warmup_scale = min(0.1 + warmup, 1.0)
|
||||
# alpha = 1.0 means fully use this encoder layer, 0.0 would mean
|
||||
# completely bypass it.
|
||||
alpha = warmup_scale if self.training else 1.0
|
||||
|
||||
# macaron style feed forward module
|
||||
src = src + self.feed_forward_macaron(src)
|
||||
|
||||
@ -262,13 +250,6 @@ class ConformerEncoderLayer(nn.Module):
|
||||
|
||||
src = self.norm_final(self.balancer(src))
|
||||
|
||||
if alpha != 1.0 or layerdrop_scale is not None:
|
||||
# the if(self.training) part is to ensure we have a derivative for
|
||||
# self.scale_alpha.
|
||||
src_offset = src - src_orig
|
||||
scale = alpha * layerdrop_scale
|
||||
src = src_orig + src_offset * scale
|
||||
|
||||
return src, attn_scores_out
|
||||
|
||||
|
||||
@ -383,7 +364,6 @@ class ConformerEncoder(nn.Module):
|
||||
pos_emb: Tensor,
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
warmup: float = 1.0,
|
||||
) -> Tensor:
|
||||
r"""Pass the input through the encoder layers in turn.
|
||||
|
||||
@ -437,7 +417,6 @@ class ConformerEncoder(nn.Module):
|
||||
attn_scores,
|
||||
src_mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
warmup=warmup,
|
||||
layerdrop_scale=layerdrop_scales[i],
|
||||
)
|
||||
output = output * feature_mask
|
||||
@ -564,7 +543,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
channel_dim=-1, max_abs=10.0,
|
||||
min_positive=0.0, max_positive=1.0)
|
||||
self.out_proj = ScaledLinear(
|
||||
embed_dim // 2, embed_dim, bias=True, initial_scale=0.5
|
||||
embed_dim // 2, embed_dim, bias=True, initial_scale=0.05
|
||||
)
|
||||
|
||||
self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
|
||||
@ -982,7 +961,7 @@ class ConvolutionModule(nn.Module):
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
initial_scale=0.5,
|
||||
initial_scale=0.05,
|
||||
)
|
||||
|
||||
def forward(self,
|
||||
@ -1257,14 +1236,12 @@ def _test_conformer_main():
|
||||
f = c(
|
||||
torch.randn(batch_size, seq_len, feature_dim),
|
||||
torch.full((batch_size,), seq_len, dtype=torch.int64),
|
||||
warmup=0.5,
|
||||
)
|
||||
f # to remove flake8 warnings
|
||||
c.eval()
|
||||
f = c(
|
||||
torch.randn(batch_size, seq_len, feature_dim),
|
||||
torch.full((batch_size,), seq_len, dtype=torch.int64),
|
||||
warmup=0.5,
|
||||
)
|
||||
f # to remove flake8 warnings
|
||||
|
||||
|
||||
@ -75,7 +75,6 @@ class Transducer(nn.Module):
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
warmup: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -96,9 +95,6 @@ class Transducer(nn.Module):
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
warmup:
|
||||
A value warmup >= 0 that determines which modules are active, values
|
||||
warmup > 1 "are fully warmed up" and all modules will be active.
|
||||
Returns:
|
||||
Return the transducer loss.
|
||||
|
||||
@ -114,7 +110,7 @@ class Transducer(nn.Module):
|
||||
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0
|
||||
|
||||
encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
|
||||
encoder_out, x_lens = self.encoder(x, x_lens)
|
||||
assert torch.all(x_lens > 0)
|
||||
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
|
||||
@ -619,7 +619,6 @@ def compute_loss(
|
||||
prune_range=params.prune_range,
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
warmup=warmup,
|
||||
)
|
||||
# after the main warmup step, we keep pruned_loss_scale small
|
||||
# for the same amount of time (model_warm_step), to avoid
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user