mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Merge branch 'scaled_adam_exp725' into scaled_adam_exp736
# Conflicts:
#	egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
#	egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
This commit is contained in:
commit
6277a5ab4b
@ -210,7 +210,7 @@ class Zipformer(EncoderInterface):
|
||||
# initializes self.skip_layers and self.skip_modules
|
||||
self._init_skip_modules()
|
||||
|
||||
self.downsample_output = AttentionDownsample(encoder_dim[-1],
|
||||
self.downsample_output = SimpleDownsample(encoder_dim[-1],
|
||||
encoder_dim[-1],
|
||||
downsample=output_downsampling_factor,
|
||||
dropout=dropout)
|
||||
@ -678,7 +678,7 @@ class DownsampledZipformerEncoder(nn.Module):
|
||||
dropout: FloatLike):
|
||||
super(DownsampledZipformerEncoder, self).__init__()
|
||||
self.downsample_factor = downsample
|
||||
self.downsample = AttentionDownsample(input_dim, output_dim,
|
||||
self.downsample = SimpleDownsample(input_dim, output_dim,
|
||||
downsample, dropout)
|
||||
self.encoder = encoder
|
||||
self.upsample = SimpleUpsample(output_dim, downsample)
|
||||
@ -731,7 +731,7 @@ class DownsampledZipformerEncoder(nn.Module):
|
||||
|
||||
|
||||
|
||||
class AttentionDownsample(torch.nn.Module):
|
||||
class SimpleDownsample(torch.nn.Module):
|
||||
"""
|
||||
Does downsampling with attention, by weighted sum, and a projection..
|
||||
"""
|
||||
@ -743,8 +743,8 @@ class AttentionDownsample(torch.nn.Module):
|
||||
"""
|
||||
Require out_channels > in_channels.
|
||||
"""
|
||||
super(AttentionDownsample, self).__init__()
|
||||
self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
|
||||
super(SimpleDownsample, self).__init__()
|
||||
|
||||
self.bias = nn.Parameter(torch.zeros(downsample))
|
||||
|
||||
self.name = None # will be set from training code
|
||||
@ -779,24 +779,10 @@ class AttentionDownsample(torch.nn.Module):
|
||||
assert src.shape[0] == d_seq_len * ds
|
||||
|
||||
src = src.reshape(d_seq_len, ds, batch_size, in_channels)
|
||||
# scores: (d_seq_len, downsample, batch_size, 1)
|
||||
scores = (src * self.query).sum(dim=-1, keepdim=True)
|
||||
scores = scores + self.bias.unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
scores = penalize_abs_values_gt(scores,
|
||||
limit=20.0,
|
||||
penalty=1.0e-04,
|
||||
name=self.name)
|
||||
|
||||
dropout = float(self.dropout)
|
||||
if dropout > 0.0:
|
||||
# the 0:1, done on the axis of size 'downsample', selects just
|
||||
# one dimension while keeping the dim. We'll then broadcast when
|
||||
# we multiply.
|
||||
dropout_mask = torch.rand_like(scores[:, 0:1]) > dropout
|
||||
scores = scores * dropout_mask
|
||||
|
||||
weights = scores.softmax(dim=1)
|
||||
weights = self.bias.softmax(dim=0)
|
||||
# weights: (downsample, 1, 1)
|
||||
weights = weights.unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
# ans1 is the first `in_channels` channels of the output
|
||||
ans = (src * weights).sum(dim=1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user