Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Cosmetic fixes
commit 308059edba (parent f8210e1d80)
@@ -105,7 +105,8 @@ class Zipformer(EncoderInterface):
         super(Zipformer, self).__init__()
 
         def _to_tuple(x):
-            """ Converts a single int or a 1-tuple of an int to a tuple with the same length as output_downsampling_factor"""
+            """ Converts a single int or a 1-tuple of an int to a tuple with the same length
+            as downsampling_factor"""
             if isinstance(x, int):
                 x = (x,)
             if len(x) == 1:
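
For context, a minimal sketch of what _to_tuple appears to do, based only on the lines visible in this hunk; the broadcast length (num_encoders) and the final assert are assumptions, not the repository code:

    # Sketch only: broadcast a per-stack hyperparameter the way the visible
    # context suggests. num_encoders stands in for len(downsampling_factor).
    def _to_tuple(x, num_encoders=6):
        """Converts a single int or a 1-tuple of an int to a tuple with the same
        length as downsampling_factor."""
        if isinstance(x, int):
            x = (x,)
        if len(x) == 1:
            x = x * num_encoders          # repeat the single value for every encoder stack
        else:
            assert len(x) == num_encoders  # assumed check, not visible in the diff
        return x

    print(_to_tuple(31))                   # (31, 31, 31, 31, 31, 31)
    print(_to_tuple((3, 3, 3, 3, 3, 3)))   # passed through unchanged
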
@@ -152,15 +153,15 @@ class Zipformer(EncoderInterface):
         num_encoders = len(downsampling_factor)
         for i in range(num_encoders):
             encoder_layer = ZipformerEncoderLayer(
-                encoder_dim[i],
-                pos_dim,
-                num_heads[i],
-                query_head_dim[i],
-                pos_head_dim[i],
-                value_head_dim[i],
-                feedforward_dim[i],
-                dropout,
-                cnn_module_kernel[i],
+                embed_dim=encoder_dim[i],
+                pos_dim=pos_dim,
+                num_heads=num_heads[i],
+                query_head_dim=query_head_dim[i],
+                pos_head_dim=pos_head_dim[i],
+                value_head_dim=value_head_dim[i],
+                feedforward_dim=feedforward_dim[i],
+                dropout=dropout,
+                cnn_module_kernel=cnn_module_kernel[i],
             )
 
             # For the segment of the warmup period, we let the Conv2dSubsampling
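
This appears to be purely a call-site cleanup: the nine per-stack values were already passed in matching order, and the new form binds each one by name. A toy illustration (hypothetical LayerCfg class, not icefall code) of why that matters when a signature has many same-typed parameters:

    from dataclasses import dataclass

    @dataclass
    class LayerCfg:                 # hypothetical stand-in for a long signature
        embed_dim: int
        num_heads: int
        feedforward_dim: int

    # Positional: swapping two ints goes unnoticed until training misbehaves.
    wrong = LayerCfg(384, 1536, 8)  # num_heads and feedforward_dim silently swapped
    # Keyword: each value stays bound to the intended parameter even if the
    # signature is later reordered.
    right = LayerCfg(embed_dim=384, num_heads=8, feedforward_dim=1536)
    print(wrong)
    print(right)
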
@@ -168,8 +169,8 @@ class Zipformer(EncoderInterface):
             encoder = ZipformerEncoder(
                 encoder_layer,
                 num_encoder_layers[i],
-                pos_dim,
-                dropout,
+                pos_dim=pos_dim,
+                dropout=dropout,
                 warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                 warmup_end=warmup_batches * (i + 2) / (num_encoders + 1)
             )
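
The warmup_begin / warmup_end arguments stagger the warmup window across the encoder stacks. A worked example with assumed values (warmup_batches=4000, num_encoders=5; not taken from this diff):

    warmup_batches, num_encoders = 4000.0, 5   # assumed values for illustration
    for i in range(num_encoders):
        begin = warmup_batches * (i + 1) / (num_encoders + 1)
        end = warmup_batches * (i + 2) / (num_encoders + 1)
        print(f"stack {i}: warmup over batches {begin:.0f}..{end:.0f}")
    # stack 0 warms up over 667..1333, stack 4 over 3333..4000: later stacks start later.
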
@@ -372,8 +373,9 @@ class ZipformerEncoderLayer(nn.Module):
         self.batch_count = 0
 
         self.self_attn_weights = RelPositionMultiheadAttentionWeights(
-            embed_dim, pos_dim, num_heads,
-            query_head_dim, pos_head_dim, dropout=0.0,
+            embed_dim, pos_dim=pos_dim, num_heads=num_heads,
+            query_head_dim=query_head_dim, pos_head_dim=pos_head_dim,
+            dropout=0.0,
         )
 
         self.self_attn1 = SelfAttention(embed_dim, num_heads,
@@ -1099,17 +1101,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             bias=False,
             initial_scale=0.05)
 
-        # this is to stop a failure mode where the output gets very small and is
-        # dominated by the mean (the min_positive and max_positive will stop the mean
-        # being much larger than the variance). Make min_abs very small because even for normal,
-        # functional self_attn layers, the output rms can be very small.
-        self.out_balancer = ActivationBalancer(embed_dim,
-                                               channel_dim=-1,
-                                               min_positive=0.33,
-                                               max_positive=0.66,
-                                               min_abs=0.005, max_abs=1.0,
-                                               min_prob=0.05)
-
         # the following are for diagnosics only, see --print-diagnostics option
         self.copy_pos_query = Identity()
@@ -1141,8 +1132,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         query_head_dim = self.query_head_dim
         pos_head_dim = self.pos_head_dim
         num_heads = self.num_heads
-        dropout = self.dropout,
-        training = self.training,
 
         seq_len, batch_size, _ = x.shape
 
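
The two deleted lines each ended in a comma, so they bound dropout and training to 1-tuples rather than to the underlying values, and any truthiness test on such a tuple is always True. A short demonstration of the gotcha (toy code, not from the repository):

    training = False,                        # the trailing comma builds a 1-tuple
    print(type(training), bool(training))    # <class 'tuple'> True -- always truthy
    if training:
        print("this branch runs even though the flag is False")
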
@@ -1165,6 +1154,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
         k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)
 
+        # time1 refers to target, time2 refers to source.
         q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
         p = p.permute(2, 1, 0, 3)  # (head, batch, time1, pos_head_dim)
         k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)
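
The new comment documents the axis convention used by the permutes: time1 indexes target (query) positions and time2 indexes source (key) positions. A quick shape check with assumed toy sizes (not repository code):

    import torch

    num_heads, batch, seq_len, query_head_dim = 4, 2, 7, 32      # assumed toy sizes
    q = torch.randn(num_heads, batch, seq_len, query_head_dim)   # (head, batch, time1, d_k)
    k = torch.randn(num_heads, batch, query_head_dim, seq_len)   # (head, batch, d_k, time2)
    attn_scores = torch.matmul(q, k)
    print(attn_scores.shape)   # torch.Size([4, 2, 7, 7]) == (head, batch, time1, time2)
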
@@ -1189,7 +1179,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 
         attn_scores = torch.matmul(q, k) + pos_scores
 
-        if training and random.random() < 0.1:
+        if self.training and random.random() < 0.1:
             # This is a harder way of limiting the attention scores to not be
             # too large. It incurs a penalty if any of them has an absolute
             # value greater than 50.0. this should be outside the normal range
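
Together with the deleted `training = self.training,` line earlier in the diff, this makes the occasional score-limiting penalty depend directly on the standard nn.Module flag, which .train() and .eval() toggle. A minimal check of that flag (toy module, not repository code):

    import torch.nn as nn

    m = nn.Linear(4, 4)       # any nn.Module carries the flag
    print(m.training)         # True: modules start in training mode
    m.eval()
    print(m.training)         # False: the occasional penalty would be skipped at inference
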
@@ -1214,10 +1204,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 
         if key_padding_mask is not None:
             assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
 
-            attn_scores = attn_scores.view(
-                num_heads, batch_size, seq_len, seq_len
-            )
-            attn_scores = attn_scores.masked_fill(
-                key_padding_mask.unsqueeze(1),
-                float("-inf"),
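
For the padding mask in this hunk, key_padding_mask has shape (batch, time2); unsqueeze(1) turns it into (batch, 1, time2), which broadcasts over the head and time1 dimensions of the (head, batch, time1, time2) score tensor, masking padded source frames. A sketch with assumed toy sizes (not repository code):

    import torch

    num_heads, batch, seq_len = 4, 2, 5                      # assumed toy sizes
    attn_scores = torch.zeros(num_heads, batch, seq_len, seq_len)
    key_padding_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
    key_padding_mask[:, -2:] = True                          # last two frames are padding
    masked = attn_scores.masked_fill(key_padding_mask.unsqueeze(1), float("-inf"))
    print(masked[0, 0, 0])    # tensor([0., 0., 0., -inf, -inf])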