Remove warmup
parent bb233d3449
commit 537c3537c0
@@ -94,7 +94,7 @@ class Conformer(EncoderInterface):
 
 
     def forward(
-        self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
+        self, x: torch.Tensor, x_lens: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -103,10 +103,6 @@ class Conformer(EncoderInterface):
           x_lens:
             A tensor of shape (batch_size,) containing the number of frames in
             `x` before padding.
-          warmup:
-            A floating point value that gradually increases from 0 throughout
-            training; when it is >= 1.0 we are "fully warmed up". It is used
-            to turn modules on sequentially.
         Returns:
           Return a tuple containing 2 tensors:
             - embeddings: its shape is (batch_size, output_seq_len, d_model)
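Note: the docstring removed above describes `warmup` as a value that rises gradually from 0 during training and reaches 1.0 when the model is "fully warmed up". A minimal sketch of such a schedule, assuming it is driven by the training batch count (the name `model_warm_step` is borrowed from a comment in train.py further down; the exact formula is not shown in this diff):

# Hypothetical sketch, not taken from this commit: a warmup value that ramps
# linearly with the number of processed batches and passes 1.0 after
# `model_warm_step` batches ("fully warmed up").
def warmup_value(batch_idx_train: int, model_warm_step: int = 3000) -> float:
    return batch_idx_train / model_warm_step

print(warmup_value(0))     # 0.0 at the start of training
print(warmup_value(1500))  # 0.5 halfway through the warm-up period
print(warmup_value(6000))  # 2.0, well past "fully warmed up"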
@@ -125,7 +121,7 @@ class Conformer(EncoderInterface):
         mask = make_pad_mask(lengths)
 
         x = self.encoder(
-            x, pos_emb, src_key_padding_mask=mask, warmup=warmup
+            x, pos_emb, src_key_padding_mask=mask,
         )  # (T, N, C)
 
         x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
@@ -174,7 +170,7 @@ class ConformerEncoderLayer(nn.Module):
             DoubleSwish(),
             nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model,
-                         initial_scale=0.1),
+                         initial_scale=0.01),
         )
 
         self.feed_forward_macaron = nn.Sequential(
@@ -184,7 +180,7 @@ class ConformerEncoderLayer(nn.Module):
             DoubleSwish(),
             nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model,
-                         initial_scale=0.1),
+                         initial_scale=0.01),
         )
 
         self.conv_module = ConvolutionModule(d_model,
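The two hunks above also shrink the `initial_scale` passed to ScaledLinear from 0.1 to 0.01. As a toy illustration of what an initial_scale-style argument does (this is not icefall's ScaledLinear, which may implement the scaling differently), a layer constructed this way starts out contributing very small outputs:

import torch
import torch.nn as nn

class TinyScaledLinear(nn.Linear):
    # Toy stand-in: shrink the freshly initialized weights (and bias) by
    # `initial_scale` so the layer's initial contribution is small.
    def __init__(self, in_features: int, out_features: int,
                 initial_scale: float = 1.0, **kwargs):
        super().__init__(in_features, out_features, **kwargs)
        with torch.no_grad():
            self.weight.mul_(initial_scale)
            if self.bias is not None:
                self.bias.mul_(initial_scale)

layer = TinyScaledLinear(4, 4, initial_scale=0.01)
print(layer(torch.randn(2, 4)).abs().mean())  # roughly 100x smaller than usual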
@@ -207,7 +203,6 @@ class ConformerEncoderLayer(nn.Module):
         attn_scores_in: Optional[Tensor] = None,
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-        warmup: float = 1.0,
         layerdrop_scale: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
         """
@@ -220,8 +215,6 @@ class ConformerEncoderLayer(nn.Module):
                 passed from layer to layer.
             src_mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
-            warmup: controls selective bypass of of layers; if < 1.0, we will
-                bypass layers more frequently.
             batch_split: if not None, this layer will only be applied to
             layerdrop_scale: an optional Tensor of broadcasting with `src` that will be used as a scale
                 on the change in the embeddings made by this layer.
@@ -235,11 +228,6 @@ class ConformerEncoderLayer(nn.Module):
         """
         src_orig = src
 
-        warmup_scale = min(0.1 + warmup, 1.0)
-        # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
-        # completely bypass it.
-        alpha = warmup_scale if self.training else 1.0
-
         # macaron style feed forward module
         src = src + self.feed_forward_macaron(src)
 
@@ -262,13 +250,6 @@ class ConformerEncoderLayer(nn.Module):
 
         src = self.norm_final(self.balancer(src))
 
-        if alpha != 1.0 or layerdrop_scale is not None:
-            # the if(self.training) part is to ensure we have a derivative for
-            # self.scale_alpha.
-            src_offset = src - src_orig
-            scale = alpha * layerdrop_scale
-            src = src_orig + src_offset * scale
-
         return src, attn_scores_out
 
 
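For reference, the block removed above implemented a warmup-dependent bypass of each encoder layer. A minimal sketch condensed from those removed lines (it keeps only the warmup part and drops the optional layerdrop_scale factor that the original code also multiplied in):

import torch

def apply_warmup_bypass(src: torch.Tensor, src_orig: torch.Tensor,
                        warmup: float, training: bool) -> torch.Tensor:
    # During training, alpha grows from 0.1 toward 1.0 as warmup increases;
    # at inference the layer is always fully applied.
    alpha = min(0.1 + warmup, 1.0) if training else 1.0
    if alpha != 1.0:
        # keep only a fraction `alpha` of the change this layer made
        src = src_orig + alpha * (src - src_orig)
    return src

x_in = torch.randn(2, 3, 4)          # (T, N, C) input to the layer
x_out = x_in + torch.randn(2, 3, 4)  # stand-in for the layer's output
print(apply_warmup_bypass(x_out, x_in, warmup=0.2, training=True))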
@@ -383,7 +364,6 @@ class ConformerEncoder(nn.Module):
         pos_emb: Tensor,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-        warmup: float = 1.0,
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
 
@@ -437,7 +417,6 @@ class ConformerEncoder(nn.Module):
                 attn_scores,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
-                warmup=warmup,
                 layerdrop_scale=layerdrop_scales[i],
             )
             output = output * feature_mask
@@ -564,7 +543,7 @@ class RelPositionMultiheadAttention(nn.Module):
             channel_dim=-1, max_abs=10.0,
             min_positive=0.0, max_positive=1.0)
         self.out_proj = ScaledLinear(
-            embed_dim // 2, embed_dim, bias=True, initial_scale=0.5
+            embed_dim // 2, embed_dim, bias=True, initial_scale=0.05
         )
 
         self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
@@ -982,7 +961,7 @@ class ConvolutionModule(nn.Module):
             stride=1,
             padding=0,
             bias=bias,
-            initial_scale=0.5,
+            initial_scale=0.05,
         )
 
     def forward(self,
@@ -1257,14 +1236,12 @@ def _test_conformer_main():
     f = c(
         torch.randn(batch_size, seq_len, feature_dim),
         torch.full((batch_size,), seq_len, dtype=torch.int64),
-        warmup=0.5,
     )
     f  # to remove flake8 warnings
     c.eval()
     f = c(
         torch.randn(batch_size, seq_len, feature_dim),
         torch.full((batch_size,), seq_len, dtype=torch.int64),
-        warmup=0.5,
     )
     f  # to remove flake8 warnings
 
@@ -75,7 +75,6 @@ class Transducer(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
-        warmup: float = 1.0,
     ) -> torch.Tensor:
         """
         Args:
@@ -96,9 +95,6 @@ class Transducer(nn.Module):
           lm_scale:
             The scale to smooth the loss with lm (output of predictor network)
             part
-          warmup:
-            A value warmup >= 0 that determines which modules are active, values
-            warmup > 1 "are fully warmed up" and all modules will be active.
         Returns:
           Return the transducer loss.
 
@@ -114,7 +110,7 @@ class Transducer(nn.Module):
 
         assert x.size(0) == x_lens.size(0) == y.dim0
 
-        encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
+        encoder_out, x_lens = self.encoder(x, x_lens)
         assert torch.all(x_lens > 0)
 
         # Now for the decoder, i.e., the prediction network
@@ -619,7 +619,6 @@ def compute_loss(
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
-            warmup=warmup,
         )
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
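The trailing context lines mention keeping pruned_loss_scale small for another model_warm_step after the main warm-up ends. A minimal sketch of such a staged schedule (the thresholds, values, and function name here are assumptions for illustration, not part of this commit):

def pruned_loss_scale(batch_idx_train: int, model_warm_step: int = 3000) -> float:
    # Assumed staging: no pruned loss while warming up, a small weight for one
    # more warm-up-sized interval, full weight afterwards.
    if batch_idx_train < model_warm_step:
        return 0.0
    elif batch_idx_train < 2 * model_warm_step:
        return 0.1
    else:
        return 1.0

print([pruned_loss_scale(i) for i in (0, 3000, 5000, 7000)])  # [0.0, 0.1, 0.1, 1.0]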