Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00

Merge branch 'scaled_adam_exp236' into scaled_adam_exp242

This commit is contained in:
commit 4da4a3a5df
@@ -331,6 +331,8 @@ class ZipformerEncoderLayer(nn.Module):
             d_model, attention_dim, nhead, pos_dim, dropout=0.0,
         )
 
+        self.pooling = PoolingModule(d_model)
+
         self.feed_forward1 = FeedforwardModule(d_model,
                                                feedforward_dim,
                                                dropout)
@@ -372,9 +374,9 @@ class ZipformerEncoderLayer(nn.Module):
             # ensure we get grads if self.bypass_scale becomes out of range
             return self.bypass_scale
         # hardcode warmup period for bypass scale
-        warmup_period = 4000.0
-        initial_clamp_min = 0.5
-        final_clamp_min = 0.2
+        warmup_period = 20000.0
+        initial_clamp_min = 0.75
+        final_clamp_min = 0.25
         if self.batch_count > warmup_period:
             clamp_min = final_clamp_min
         else:
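For reference, the bypass-scale clamp tuned in this hunk follows a simple warm-up schedule: clamp_min decays from initial_clamp_min to final_clamp_min over warmup_period batches and stays at final_clamp_min afterwards. The helper below is a hypothetical standalone sketch of that schedule, assuming the unchanged first line of the expression (just outside this hunk) reads clamp_min = (initial_clamp_min - ...); it is not code from the repository.

# Hypothetical sketch of the clamp_min schedule implied by this hunk; the name
# bypass_clamp_min and the exact interpolation are inferred, not copied from icefall.
def bypass_clamp_min(batch_count: float,
                     warmup_period: float = 20000.0,
                     initial_clamp_min: float = 0.75,
                     final_clamp_min: float = 0.25) -> float:
    if batch_count > warmup_period:
        return final_clamp_min
    # linear decay from initial_clamp_min down to final_clamp_min during warm-up
    return (initial_clamp_min -
            (batch_count / warmup_period) * (initial_clamp_min - final_clamp_min))

# e.g. bypass_clamp_min(0.0) == 0.75, bypass_clamp_min(10000.0) == 0.5,
# bypass_clamp_min(25000.0) == 0.25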
@@ -382,6 +384,21 @@ class ZipformerEncoderLayer(nn.Module):
                          (self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min))
         return self.bypass_scale.clamp(min=clamp_min, max=1.0)
 
+    def get_dynamic_dropout_rate(self):
+        # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this
+        # starts at 0.2 and rapidly decreases to 0. Its purpose is to keep the training stable
+        # at the beginning, by making the network focus on the feedforward modules.
+        if torch.jit.is_scripting() or not self.training:
+            return 0.0
+        warmup_period = 2000.0
+        initial_dropout_rate = 0.2
+        final_dropout_rate = 0.0
+        if self.batch_count > warmup_period:
+            return final_dropout_rate
+        else:
+            return (initial_dropout_rate -
+                    (initial_dropout_rate * final_dropout_rate) * (self.batch_count / warmup_period))
+
     def forward(
         self,
         src: Tensor,
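A note on the schedule added above: because final_dropout_rate is 0.0, the product term initial_dropout_rate * final_dropout_rate is zero, so the committed expression evaluates to a constant 0.2 for the whole warm-up and drops to 0.0 once self.batch_count exceeds warmup_period; a linear ramp toward zero would use the difference of the two rates rather than their product. The helper below is a hypothetical standalone restatement of the committed formula for illustration only; it is not part of the diff.

# Sketch of how the committed schedule evaluates (hypothetical helper, not repository code).
# With final_dropout_rate = 0.0 the product term vanishes, so the returned rate is a
# step function: 0.2 up to warmup_period, then 0.0.
def dynamic_dropout_rate(batch_count: float,
                         warmup_period: float = 2000.0,
                         initial_dropout_rate: float = 0.2,
                         final_dropout_rate: float = 0.0) -> float:
    if batch_count > warmup_period:
        return final_dropout_rate
    return (initial_dropout_rate -
            (initial_dropout_rate * final_dropout_rate) * (batch_count / warmup_period))

# dynamic_dropout_rate(0.0) == 0.2, dynamic_dropout_rate(1999.0) == 0.2,
# dynamic_dropout_rate(2001.0) == 0.0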
@@ -411,24 +428,37 @@ class ZipformerEncoderLayer(nn.Module):
         # macaron style feed forward module
         src = src + self.feed_forward1(src)
 
+        # dropout rate for submodules that interact with time.
+        dynamic_dropout = self.get_dynamic_dropout_rate()
+
+        # pooling module
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+            src = src + self.pooling(src,
+                                     key_padding_mask=src_key_padding_mask)
+
         # multi-headed self-attention module
-        src_att, attn_weights = self.self_attn(
-            src,
-            pos_emb=pos_emb,
-            attn_mask=src_mask,
-            key_padding_mask=src_key_padding_mask,
-        )
-        src = src + src_att
+        use_self_attn = (random.random() > dynamic_dropout)
+        if torch.jit.is_scripting() or use_self_attn:
+            src_att, attn_weights = self.self_attn(
+                src,
+                pos_emb=pos_emb,
+                attn_mask=src_mask,
+                key_padding_mask=src_key_padding_mask,
+            )
+            src = src + src_att
 
         # convolution module
-        src = src + self.conv_module1(src, src_key_padding_mask=src_key_padding_mask)
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+            src = src + self.conv_module1(src, src_key_padding_mask=src_key_padding_mask)
 
 
         src = src + self.feed_forward2(src)
 
-        src = src + self.self_attn.forward2(src, attn_weights)
+        if torch.jit.is_scripting() or use_self_attn:
+            src = src + self.self_attn.forward2(src, attn_weights)
 
-        src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask)
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+            src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask)
 
         src = src + self.feed_forward3(src)
 
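The hunk above applies a layer-drop-style regularization at the sub-module level: during early training each of the pooling, self-attention, and convolution branches is skipped with probability dynamic_dropout, while the feed-forward branches always run. The self-attention decision is cached in use_self_attn so that self_attn.forward2, which reuses attn_weights from the first attention pass, is skipped consistently with it; otherwise attn_weights could be undefined. The helper below is a minimal sketch of the skip-or-apply pattern under these assumptions (maybe_apply is a hypothetical name, not repository code).

import random
import torch

# Run `fn` and add its output residually with probability (1 - drop_rate) during
# training; always run it when scripting (or when drop_rate is 0).
def maybe_apply(src: torch.Tensor, fn, drop_rate: float) -> torch.Tensor:
    if torch.jit.is_scripting() or random.random() > drop_rate:
        return src + fn(src)
    return src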
@@ -1397,6 +1427,43 @@ class RelPositionMultiheadAttention(nn.Module):
             logging.info(f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}")
 
 
+
+
+class PoolingModule(nn.Module):
+    """
+    Averages the input over the time dimension and project with a square matrix.
+    """
+    def __init__(self,
+                 d_model: int):
+        super().__init__()
+        self.proj = ScaledLinear(d_model, d_model,
+                                 initial_scale=0.1, bias=False)
+
+    def forward(self,
+                x: Tensor,
+                key_padding_mask):
+        """
+        Args:
+           x: a Tensor of shape (T, N, C)
+           key_padding_mask: a Tensor of bool, of shape (N, T), with True in masked
+             positions.
+        Returns:
+           a Tensor of shape (1, N, C)
+        """
+        if key_padding_mask is not None:
+            pooling_mask = key_padding_mask.logical_not().to(x.dtype)  # (N, T)
+            pooling_mask = (pooling_mask / pooling_mask.sum(dim=1, keepdim=True))
+            pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1)
+            # now pooling_mask: (T, N, 1)
+        else:
+            num_frames = x.shape[0]
+            pooling_mask = 1.0 / num_frames
+
+        x = (x * pooling_mask).sum(dim=0, keepdim=True)
+        x = self.proj(x)
+        return x
+
+
 class FeedforwardModule(nn.Module):
     """Feedforward module in Zipformer model.
     """
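PoolingModule, added in the hunk above, computes a masked mean of the input over the time axis and projects it back to d_model with a square matrix; in the encoder layer's forward pass the resulting (1, N, C) summary is added to src of shape (T, N, C), broadcasting over time. The sketch below mirrors that computation as a self-contained module; nn.Linear stands in for icefall's ScaledLinear, and SimplePooling is a hypothetical name used only for illustration.

import torch
from torch import nn

class SimplePooling(nn.Module):
    """Masked mean over time followed by a square projection, mirroring the
    PoolingModule added in this commit (ScaledLinear replaced by nn.Linear)."""
    def __init__(self, d_model: int):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None) -> torch.Tensor:
        # x: (T, N, C); key_padding_mask: (N, T), True at padded positions.
        if key_padding_mask is not None:
            w = key_padding_mask.logical_not().to(x.dtype)      # (N, T)
            w = w / w.sum(dim=1, keepdim=True)                  # normalize per sequence
            w = w.transpose(0, 1).unsqueeze(-1)                 # (T, N, 1)
            pooled = (x * w).sum(dim=0, keepdim=True)           # (1, N, C)
        else:
            pooled = x.mean(dim=0, keepdim=True)                # (1, N, C)
        return self.proj(pooled)

x = torch.randn(50, 4, 256)                   # (T, N, C)
mask = torch.zeros(4, 50, dtype=torch.bool)   # no padding by default
mask[0, 40:] = True                           # pad the tail of the first sequence
out = SimplePooling(256)(x, mask)             # -> shape (1, 4, 256)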