mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Implement dropout for scores in AttentionDownsample
This commit is contained in:
parent
2f617fec43
commit
cb12014c31
@ -202,6 +202,7 @@ class Zipformer(EncoderInterface):
|
|||||||
input_dim=encoder_dim[i-1] if i > 0 else encoder_dim[0],
|
input_dim=encoder_dim[i-1] if i > 0 else encoder_dim[0],
|
||||||
output_dim=encoder_dim[i],
|
output_dim=encoder_dim[i],
|
||||||
downsample=downsampling_factor[i],
|
downsample=downsampling_factor[i],
|
||||||
|
dropout=dropout,
|
||||||
)
|
)
|
||||||
encoders.append(encoder)
|
encoders.append(encoder)
|
||||||
self.encoders = nn.ModuleList(encoders)
|
self.encoders = nn.ModuleList(encoders)
|
||||||
@ -211,7 +212,8 @@ class Zipformer(EncoderInterface):
|
|||||||
|
|
||||||
self.downsample_output = AttentionDownsample(encoder_dim[-1],
|
self.downsample_output = AttentionDownsample(encoder_dim[-1],
|
||||||
encoder_dim[-1],
|
encoder_dim[-1],
|
||||||
downsample=output_downsampling_factor)
|
downsample=output_downsampling_factor,
|
||||||
|
dropout=dropout)
|
||||||
|
|
||||||
|
|
||||||
def _init_skip_modules(self):
|
def _init_skip_modules(self):
|
||||||
@ -677,10 +679,12 @@ class DownsampledZipformerEncoder(nn.Module):
|
|||||||
encoder: nn.Module,
|
encoder: nn.Module,
|
||||||
input_dim: int,
|
input_dim: int,
|
||||||
output_dim: int,
|
output_dim: int,
|
||||||
downsample: int):
|
downsample: int,
|
||||||
|
dropout: FloatLike):
|
||||||
super(DownsampledZipformerEncoder, self).__init__()
|
super(DownsampledZipformerEncoder, self).__init__()
|
||||||
self.downsample_factor = downsample
|
self.downsample_factor = downsample
|
||||||
self.downsample = AttentionDownsample(input_dim, output_dim, downsample)
|
self.downsample = AttentionDownsample(input_dim, output_dim,
|
||||||
|
downsample, dropout)
|
||||||
self.encoder = encoder
|
self.encoder = encoder
|
||||||
self.upsample = SimpleUpsample(output_dim, downsample)
|
self.upsample = SimpleUpsample(output_dim, downsample)
|
||||||
self.out_combiner = SimpleCombiner(input_dim,
|
self.out_combiner = SimpleCombiner(input_dim,
|
||||||
@ -794,7 +798,8 @@ class AttentionDownsample(torch.nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
out_channels: int,
|
out_channels: int,
|
||||||
downsample: int):
|
downsample: int,
|
||||||
|
dropout: FloatLike):
|
||||||
"""
|
"""
|
||||||
Require out_channels > in_channels.
|
Require out_channels > in_channels.
|
||||||
"""
|
"""
|
||||||
@ -802,6 +807,7 @@ class AttentionDownsample(torch.nn.Module):
|
|||||||
self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
|
self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
|
||||||
|
|
||||||
self.name = None # will be set from training code
|
self.name = None # will be set from training code
|
||||||
|
self.dropout = copy.deepcopy(dropout)
|
||||||
|
|
||||||
# fill in the extra dimensions with a projection of the input
|
# fill in the extra dimensions with a projection of the input
|
||||||
if out_channels > in_channels:
|
if out_channels > in_channels:
|
||||||
@ -832,6 +838,7 @@ class AttentionDownsample(torch.nn.Module):
|
|||||||
assert src.shape[0] == d_seq_len * ds
|
assert src.shape[0] == d_seq_len * ds
|
||||||
|
|
||||||
src = src.reshape(d_seq_len, ds, batch_size, in_channels)
|
src = src.reshape(d_seq_len, ds, batch_size, in_channels)
|
||||||
|
# scores: (d_seq_len, downsample, batch_size)
|
||||||
scores = (src * self.query).sum(dim=-1, keepdim=True)
|
scores = (src * self.query).sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
scores = penalize_abs_values_gt(scores,
|
scores = penalize_abs_values_gt(scores,
|
||||||
@ -839,6 +846,14 @@ class AttentionDownsample(torch.nn.Module):
|
|||||||
penalty=1.0e-04,
|
penalty=1.0e-04,
|
||||||
name=self.name)
|
name=self.name)
|
||||||
|
|
||||||
|
dropout = float(self.dropout)
|
||||||
|
if dropout > 0.0:
|
||||||
|
# the 0:1, done on the axis of size 'downsample', selects just
|
||||||
|
# one dimension while keeping the dim. We'll then broadcast when
|
||||||
|
# we multiply.
|
||||||
|
dropout_mask = torch.rand_like(scores[:, 0:1]) > dropout
|
||||||
|
scores = scores * dropout_mask
|
||||||
|
|
||||||
weights = scores.softmax(dim=1)
|
weights = scores.softmax(dim=1)
|
||||||
|
|
||||||
# ans1 is the first `in_channels` channels of the output
|
# ans1 is the first `in_channels` channels of the output
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user