mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Change scale_min of Conv2dSubsampling from .01 to .1; some cosmetic changes/unimportant bugfixes.
This commit is contained in:
parent
d682ecc246
commit
025bcc155d
@ -1384,10 +1384,12 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
|
|||||||
# (num_heads, batch_size, seq_len, seq_len) x (num_heads, batch_size, seq_len, head_dim)
|
# (num_heads, batch_size, seq_len, seq_len) x (num_heads, batch_size, seq_len, head_dim)
|
||||||
# -> (num_heads, batch_size, seq_len, head_dim)
|
# -> (num_heads, batch_size, seq_len, head_dim)
|
||||||
bottleneck = torch.matmul(attn_weights, bottleneck)
|
bottleneck = torch.matmul(attn_weights, bottleneck)
|
||||||
bottleneck = self.bottleneck_balancer(bottleneck)
|
|
||||||
bottleneck = self.bottleneck_activation(bottleneck)
|
|
||||||
bottleneck = bottleneck.permute(2, 1, 0, 3) # (seq_len, batch_size, num_heads, head_dim)
|
bottleneck = bottleneck.permute(2, 1, 0, 3) # (seq_len, batch_size, num_heads, head_dim)
|
||||||
bottleneck = bottleneck.reshape(seq_len, batch_size, bottleneck_dim)
|
bottleneck = bottleneck.reshape(seq_len, batch_size, bottleneck_dim)
|
||||||
|
|
||||||
|
bottleneck = self.bottleneck_balancer(bottleneck)
|
||||||
|
bottleneck = self.bottleneck_activation(bottleneck)
|
||||||
scales = self.from_bottleneck_proj(bottleneck)
|
scales = self.from_bottleneck_proj(bottleneck)
|
||||||
|
|
||||||
x = self.in_proj(x)
|
x = self.in_proj(x)
|
||||||
@ -1505,6 +1507,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
|
|||||||
|
|
||||||
s = s.unsqueeze(-1).expand(-1, -1, -1, self.ratio).reshape(seq_len, batch_size, num_channels // 2)
|
s = s.unsqueeze(-1).expand(-1, -1, -1, self.ratio).reshape(seq_len, batch_size, num_channels // 2)
|
||||||
|
|
||||||
|
x = self.activation(x) # diagnostics only, it's the identity.
|
||||||
x = x * s
|
x = x * s
|
||||||
|
|
||||||
(seq_len, batch_size, embed_dim) = x.shape
|
(seq_len, batch_size, embed_dim) = x.shape
|
||||||
@ -1517,7 +1520,6 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
|
|||||||
# now x: (num_heads, batch_size, seq_len, head_dim)
|
# now x: (num_heads, batch_size, seq_len, head_dim)
|
||||||
x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
|
x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
|
||||||
|
|
||||||
x = self.activation(x) # diagnostics only, it's the identity.
|
|
||||||
x = self.whiten(x)
|
x = self.whiten(x)
|
||||||
x = self.out_proj(x)
|
x = self.out_proj(x)
|
||||||
return x
|
return x
|
||||||
@ -1720,7 +1722,7 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
|
|
||||||
self.scale = nn.Parameter(torch.ones(out_height * layer3_channels))
|
self.scale = nn.Parameter(torch.ones(out_height * layer3_channels))
|
||||||
self.scale_max = 1.0
|
self.scale_max = 1.0
|
||||||
self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.01))
|
self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.1))
|
||||||
|
|
||||||
self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels,
|
self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels,
|
||||||
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())
|
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user